14790 Commits

Author SHA1 Message Date
jax authors
357b48d29a Merge pull request #14391 from ROCmSoftwarePlatform:rocm_switch_to_rocm54
PiperOrigin-RevId: 508497281
2023-02-09 15:50:49 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
jax authors
12dc73dc6e Merge pull request #14388 from jakevdp:bcsr-todense-ad
PiperOrigin-RevId: 508477843
2023-02-09 14:41:41 -08:00
Jieying Luo
668b82d529 [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
7651866b1d [sparse] implement autodiff rules for bcsr primitives 2023-02-09 14:19:22 -08:00
Jake VanderPlas
15c9bca67f [sparse] add cusparse lowering for simplest cases of bcsr_dot_general
PiperOrigin-RevId: 508473938
2023-02-09 14:18:44 -08:00
jax authors
253cd4d9d1 Merge pull request #14387 from ROCmSoftwarePlatform:rocm_reenable_dirichlet_test
PiperOrigin-RevId: 508466026
2023-02-09 13:50:06 -08:00
Peter Hawkins
88cc254f2c [JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
0c14e9ab49 Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.

Fix some overly-strict type annotations.

PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
jax authors
adcceb228f Merge pull request #14384 from mattjj:pjit-pretty-print
PiperOrigin-RevId: 508454299
2023-02-09 13:04:58 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Rahul Batra
7d0d9b706e [ROCm]: Re-enable Dirichlet Tests on ROCm 2023-02-09 20:19:07 +00:00
Rahul Batra
023226e181 [ROCm]: Move dockerfile to ROCm5.4 2023-02-09 20:08:35 +00:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
jax authors
3f8cb0a7c9 Merge pull request #14379 from mattjj:shmap-vmap-spmd-axis-name
PiperOrigin-RevId: 508292029
2023-02-09 00:14:30 -08:00
Matthew Johnson
6fb3ace5d0 [shard-map] add vmap spmd_axis_name support, fix vmap rule bug 2023-02-08 23:54:28 -08:00
jax authors
bd7c227e96 Merge pull request #14373 from mattjj:shmap-check-rep-false
PiperOrigin-RevId: 508219490
2023-02-08 16:49:29 -08:00
Matthew Johnson
1a03f34383 [shard-map] if check_rep=False, don't call rep rules in eager 2023-02-08 15:42:35 -08:00
jax authors
ccb974a150 Merge pull request #14370 from jakevdp:argpartition-impl
PiperOrigin-RevId: 508194466
2023-02-08 15:10:50 -08:00
Peter Hawkins
a28b01243b Move contents of jax.monitoring to jax._src.monitoring.
PiperOrigin-RevId: 508191560
2023-02-08 15:03:22 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
4fbaee5920 Implement jax.numpy.argpartition 2023-02-08 14:41:39 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
jax authors
3e349c7bed Merge pull request #14361 from jakevdp:doc-topk
PiperOrigin-RevId: 508181335
2023-02-08 14:19:01 -08:00
Yash Katariya
e4d551a217 Remove the doctest skip now that jit and pjit have been merged
PiperOrigin-RevId: 508162840
2023-02-08 13:09:53 -08:00
jax authors
1254d44dbd Remove silent data corruption runtime flags from persistent cache key.
These flags have no effect on the compiled executable, just the runtime execution.

PiperOrigin-RevId: 508152877
2023-02-08 12:31:27 -08:00
Ashish Shenoy
f71a55c554 Rename tensorflow core target variable to tensorflow_core
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
jax authors
9a1f9b1ef8 Merge pull request #14362 from mattjj:shmap-remat
PiperOrigin-RevId: 508139783
2023-02-08 11:42:12 -08:00
Jake VanderPlas
794557d349 tril_indices/triu_indices: fix call signature & add type annotations 2023-02-08 11:19:06 -08:00
Matthew Johnson
58d3f552d7 [shard-map] add remat support, very basic test 2023-02-08 11:15:38 -08:00
Jake VanderPlas
3c6183498a lax.top_k: improve documentation and errors on invalid values 2023-02-08 11:07:56 -08:00
jax authors
4844e3f85c Merge pull request #14357 from skye:version
PiperOrigin-RevId: 508119911
2023-02-08 10:42:48 -08:00
jax authors
4358b803e9 Merge pull request #14355 from jakevdp:tril-indices
PiperOrigin-RevId: 508119785
2023-02-08 10:35:10 -08:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Skye Wanderman-Milne
21f12183bf Post 0.4.3 release updates 2023-02-08 10:08:59 -08:00
Jake VanderPlas
a76a024548 tril/triu_indices: compute arrays at runtime 2023-02-08 09:52:41 -08:00
Roy Frostig
55c2b6dad6 move jax.interpreters.batching to jax._src.interpreters.batching
Re-export roughly all of the same symbols via `jax.interpreters.batching` for now.

PiperOrigin-RevId: 508107044
2023-02-08 09:51:00 -08:00
jax authors
85b0c88490 Merge pull request #14353 from google:jaxpr_fix
PiperOrigin-RevId: 508098163
2023-02-08 09:15:57 -08:00
yashkatariya
2cfec044bf Fix the jaxpr after jit-pjit merge 2023-02-08 09:12:01 -08:00
yashkatariya
d3eef935f7 Fix the jaxpr after jit-pjit merge 2023-02-08 08:52:57 -08:00
Parker Schuh
c3e6d5cb2a Remove some differences between jit and pjit.
- MaybeCollectGarbage
- The recursive check.
- DevicePut for np arrays and scalars when device_count == 1.

PiperOrigin-RevId: 507972281
2023-02-07 21:33:07 -08:00
jax authors
7783a5e129 Merge pull request #14343 from skye:cache_options_check
PiperOrigin-RevId: 507933881
jax-v0.4.3 jaxlib-v0.4.3 jax-v0.4.3-rc
2023-02-07 17:49:31 -08:00
Skye Wanderman-Milne
1228cbd26b Change the executable_build_options check in compilation_cache.py to be more robust.
Prior to this change, the check would spuriously fire on Python 3.11
because it added a default `__getstate__` method to all objects. This
change makes it so we only look at public fields and methods.
2023-02-08 01:22:14 +00:00
jax authors
a46f31bca8 Merge pull request #14342 from skye:version
PiperOrigin-RevId: 507907767
2023-02-07 15:51:49 -08:00
Skye Wanderman-Milne
8ab158574d Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release 2023-02-07 15:45:28 -08:00
Skye Wanderman-Milne
eb13c053e9 Add option to run tests with persistent compilation cache enabled.
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.

PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
jax authors
9c827fbd9a Merge pull request #14340 from ROCmSoftwarePlatform:rocm_reenable_linalg_sparse_tests
PiperOrigin-RevId: 507886628
2023-02-07 14:30:37 -08:00
Rahul Batra
01a10a1d06 [ROCm] Re-enable some linalg and sparse tests 2023-02-07 22:05:14 +00:00