Yash Katariya
2fc64bee13
Change the axis_resources
argument of with_sharding_constraint
to shardings
to match pjit
and jit
.
...
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Jake VanderPlas
58323d5b40
jax.numpy reductions: better validation of initial value
2023-02-13 08:43:25 -08:00
Yash Katariya
6caaffc20c
Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
...
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Peter Hawkins
612a940160
Minimize the set of names exported from jax.experimental.pjit.
...
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Yash Katariya
9316188b3a
[Rollback] Convert _arrays to return PyArray instead of PyBuffer.
...
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Skye Wanderman-Milne
e54858522c
Add back loading TPU plugin for older jaxlib versions.
...
This was removed in 668b82d529
.
PiperOrigin-RevId: 508777939
2023-02-10 16:16:20 -08:00
jax authors
fc507f2ebe
Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
...
PiperOrigin-RevId: 508777043
2023-02-10 16:08:32 -08:00
Yash Katariya
0d07372995
Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
...
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Parker Schuh
568a93bcd1
Convert _arrays to return PyArray instead of PyBuffer.
...
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Matthew Johnson
9538bc3e73
generalize vmap spmd_axis_name to accept tuples of axis names
...
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
jax authors
dc6bf9b725
Merge pull request #14408 from lucashofer:scipy_spence
...
PiperOrigin-RevId: 508756972
2023-02-10 14:36:15 -08:00
Yash Katariya
1526c3e20c
Improve the error message which is raised from _get_and_check_device_assignment
.
...
Before:
```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```
After:
```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Lucas Hofer
4636276214
added scipy special spence
...
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
jax authors
57900d7ef2
Merge pull request #14364 from jakevdp:fix-tril-indices
...
PiperOrigin-RevId: 508723970
2023-02-10 12:25:06 -08:00
Jake VanderPlas
60256df668
[typing] define additional methods & properties on jax.Array
...
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
2023-02-10 09:42:32 -08:00
Peter Hawkins
54ff78dbde
Deprecate jax.interpreters.xla.Device and jax.interpreters.xla.DeviceArray.
...
PiperOrigin-RevId: 508502470
2023-02-09 16:11:48 -08:00
Roy Frostig
1c84e4a753
migrate internal dependencies from jax.interpreters.batching
to jax._src.interpreters.batching
...
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.
PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Jieying Luo
668b82d529
[PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
...
Loading TPU PJRT plugin is moved to make_tpu_client.
This change is based on https://github.com/google/jax/pull/14011 .
PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Peter Hawkins
0c14e9ab49
Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
...
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.
Fix some overly-strict type annotations.
PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
Matthew Johnson
a964dc3b9a
simpler pretty-print for pjit, tweak custom pp rule signature
2023-02-09 12:45:51 -08:00
Peter Hawkins
8268cd562d
Add infrastructure for managing deprecations.
...
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.
PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
jax authors
ccb974a150
Merge pull request #14370 from jakevdp:argpartition-impl
...
PiperOrigin-RevId: 508194466
2023-02-08 15:10:50 -08:00
Peter Hawkins
a28b01243b
Move contents of jax.monitoring to jax._src.monitoring.
...
PiperOrigin-RevId: 508191560
2023-02-08 15:03:22 -08:00
Yash Katariya
7350f00acd
Remove jax_experimental_subjaxpr_lowering_cache
since it was only for jit
and was False
by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
...
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
4fbaee5920
Implement jax.numpy.argpartition
2023-02-08 14:41:39 -08:00
Peter Hawkins
cc8d7fae32
Move jax.interpreters.mlir to jax._src.interpreters.mlir.
...
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
jax authors
3e349c7bed
Merge pull request #14361 from jakevdp:doc-topk
...
PiperOrigin-RevId: 508181335
2023-02-08 14:19:01 -08:00
Yash Katariya
6ec9082cf5
Default jax_jit_pjit_api_merge
to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh
context manager.
...
This changes the internals of JAX without affecting any public API.
Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing ).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Jake VanderPlas
794557d349
tril_indices/triu_indices: fix call signature & add type annotations
2023-02-08 11:19:06 -08:00
Jake VanderPlas
3c6183498a
lax.top_k: improve documentation and errors on invalid values
2023-02-08 11:07:56 -08:00
jax authors
4358b803e9
Merge pull request #14355 from jakevdp:tril-indices
...
PiperOrigin-RevId: 508119785
2023-02-08 10:35:10 -08:00
Yash Katariya
7b1128fdc4
Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
...
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Jake VanderPlas
a76a024548
tril/triu_indices: compute arrays at runtime
2023-02-08 09:52:41 -08:00
Roy Frostig
55c2b6dad6
move jax.interpreters.batching
to jax._src.interpreters.batching
...
Re-export roughly all of the same symbols via `jax.interpreters.batching` for now.
PiperOrigin-RevId: 508107044
2023-02-08 09:51:00 -08:00
Skye Wanderman-Milne
eb13c053e9
Add option to run tests with persistent compilation cache enabled.
...
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.
PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Peter Hawkins
6860cb8d2a
Move jax.interpreters.xla to jax._src.interpreters.xla.
...
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Peter Hawkins
98b75cf27b
Prune accidental exports from jax.interpreters.pxla.
...
These imports do not appear to have users outside JAX itself.
PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
d4f422f608
Merge pull request #14303 from carlosgmartin:rankdata
...
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
carlosgmartin
8251957025
Added scipy.stats.rankdata
2023-02-07 12:07:00 -05:00
Roy Frostig
219723c738
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
...
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
Includes some import fixups along the way.
PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821
Make pjit's cache global just like jit
's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.))
is executed twice.
...
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.
PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Skye Wanderman-Milne
6cef0873e8
Don't write executables with host callbacks to persistent compilation cache.
...
The persistent compilation cache can't de/serialize the callback functions (yet?).
PiperOrigin-RevId: 507628297
2023-02-06 17:37:32 -08:00
Peter Hawkins
08ff7f4ea9
Prune accidentally exported names from jax.interpreters.ad.
...
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Peter Hawkins
38a59a313b
Move jax.interpreters.pxla to jax._src.interpreters.pxla.
...
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.
PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Peter Hawkins
3d9ae6b467
Add a .cost_analysis() on lowered but uncompiled computations.
...
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.
PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Yash Katariya
8a69444ff9
Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
...
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
63e0e0fdb6
Merge pull request #14291 from sharadmv:fix-checkify-caching
...
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Yash Katariya
a30ba83db2
Fix the latest jax jaxlib on pypi failure
...
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
973bdb203b
Copy the jit docs and paste it inside the new jit fork.
...
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Sharad Vikram
c231171fb6
Fix checkify caching with nested call primitives
2023-02-03 23:28:37 -08:00