jax authors
357b48d29a
Merge pull request #14391 from ROCmSoftwarePlatform:rocm_switch_to_rocm54
...
PiperOrigin-RevId: 508497281
2023-02-09 15:50:49 -08:00
Roy Frostig
1c84e4a753
migrate internal dependencies from jax.interpreters.batching
to jax._src.interpreters.batching
...
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.
PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
jax authors
12dc73dc6e
Merge pull request #14388 from jakevdp:bcsr-todense-ad
...
PiperOrigin-RevId: 508477843
2023-02-09 14:41:41 -08:00
Jieying Luo
668b82d529
[PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
...
Loading TPU PJRT plugin is moved to make_tpu_client.
This change is based on https://github.com/google/jax/pull/14011 .
PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
7651866b1d
[sparse] implement autodiff rules for bcsr primitives
2023-02-09 14:19:22 -08:00
Jake VanderPlas
15c9bca67f
[sparse] add cusparse lowering for simplest cases of bcsr_dot_general
...
PiperOrigin-RevId: 508473938
2023-02-09 14:18:44 -08:00
jax authors
253cd4d9d1
Merge pull request #14387 from ROCmSoftwarePlatform:rocm_reenable_dirichlet_test
...
PiperOrigin-RevId: 508466026
2023-02-09 13:50:06 -08:00
Peter Hawkins
88cc254f2c
[JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
...
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
0c14e9ab49
Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
...
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.
Fix some overly-strict type annotations.
PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
jax authors
adcceb228f
Merge pull request #14384 from mattjj:pjit-pretty-print
...
PiperOrigin-RevId: 508454299
2023-02-09 13:04:58 -08:00
Matthew Johnson
a964dc3b9a
simpler pretty-print for pjit, tweak custom pp rule signature
2023-02-09 12:45:51 -08:00
Rahul Batra
7d0d9b706e
[ROCm]: Re-enable Dirichlet Tests on ROCm
2023-02-09 20:19:07 +00:00
Rahul Batra
023226e181
[ROCm]: Move dockerfile to ROCm5.4
2023-02-09 20:08:35 +00:00
Peter Hawkins
8268cd562d
Add infrastructure for managing deprecations.
...
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.
PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
jax authors
3f8cb0a7c9
Merge pull request #14379 from mattjj:shmap-vmap-spmd-axis-name
...
PiperOrigin-RevId: 508292029
2023-02-09 00:14:30 -08:00
Matthew Johnson
6fb3ace5d0
[shard-map] add vmap spmd_axis_name support, fix vmap rule bug
2023-02-08 23:54:28 -08:00
jax authors
bd7c227e96
Merge pull request #14373 from mattjj:shmap-check-rep-false
...
PiperOrigin-RevId: 508219490
2023-02-08 16:49:29 -08:00
Matthew Johnson
1a03f34383
[shard-map] if check_rep=False, don't call rep rules in eager
2023-02-08 15:42:35 -08:00
jax authors
ccb974a150
Merge pull request #14370 from jakevdp:argpartition-impl
...
PiperOrigin-RevId: 508194466
2023-02-08 15:10:50 -08:00
Peter Hawkins
a28b01243b
Move contents of jax.monitoring to jax._src.monitoring.
...
PiperOrigin-RevId: 508191560
2023-02-08 15:03:22 -08:00
Yash Katariya
7350f00acd
Remove jax_experimental_subjaxpr_lowering_cache
since it was only for jit
and was False
by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
...
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
4fbaee5920
Implement jax.numpy.argpartition
2023-02-08 14:41:39 -08:00
Peter Hawkins
cc8d7fae32
Move jax.interpreters.mlir to jax._src.interpreters.mlir.
...
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
jax authors
3e349c7bed
Merge pull request #14361 from jakevdp:doc-topk
...
PiperOrigin-RevId: 508181335
2023-02-08 14:19:01 -08:00
Yash Katariya
e4d551a217
Remove the doctest skip now that jit and pjit have been merged
...
PiperOrigin-RevId: 508162840
2023-02-08 13:09:53 -08:00
jax authors
1254d44dbd
Remove silent data corruption runtime flags from persistent cache key.
...
These flags have no effect on the compiled executable, just the runtime execution.
PiperOrigin-RevId: 508152877
2023-02-08 12:31:27 -08:00
Ashish Shenoy
f71a55c554
Rename tensorflow core target variable to tensorflow_core
...
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
Yash Katariya
6ec9082cf5
Default jax_jit_pjit_api_merge
to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh
context manager.
...
This changes the internals of JAX without affecting any public API.
Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing ).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
jax authors
9a1f9b1ef8
Merge pull request #14362 from mattjj:shmap-remat
...
PiperOrigin-RevId: 508139783
2023-02-08 11:42:12 -08:00
Jake VanderPlas
794557d349
tril_indices/triu_indices: fix call signature & add type annotations
2023-02-08 11:19:06 -08:00
Matthew Johnson
58d3f552d7
[shard-map] add remat support, very basic test
2023-02-08 11:15:38 -08:00
Jake VanderPlas
3c6183498a
lax.top_k: improve documentation and errors on invalid values
2023-02-08 11:07:56 -08:00
jax authors
4844e3f85c
Merge pull request #14357 from skye:version
...
PiperOrigin-RevId: 508119911
2023-02-08 10:42:48 -08:00
jax authors
4358b803e9
Merge pull request #14355 from jakevdp:tril-indices
...
PiperOrigin-RevId: 508119785
2023-02-08 10:35:10 -08:00
Yash Katariya
7b1128fdc4
Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
...
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Skye Wanderman-Milne
21f12183bf
Post 0.4.3 release updates
2023-02-08 10:08:59 -08:00
Jake VanderPlas
a76a024548
tril/triu_indices: compute arrays at runtime
2023-02-08 09:52:41 -08:00
Roy Frostig
55c2b6dad6
move jax.interpreters.batching
to jax._src.interpreters.batching
...
Re-export roughly all of the same symbols via `jax.interpreters.batching` for now.
PiperOrigin-RevId: 508107044
2023-02-08 09:51:00 -08:00
jax authors
85b0c88490
Merge pull request #14353 from google:jaxpr_fix
...
PiperOrigin-RevId: 508098163
2023-02-08 09:15:57 -08:00
yashkatariya
2cfec044bf
Fix the jaxpr after jit-pjit merge
2023-02-08 09:12:01 -08:00
yashkatariya
d3eef935f7
Fix the jaxpr after jit-pjit merge
2023-02-08 08:52:57 -08:00
Parker Schuh
c3e6d5cb2a
Remove some differences between jit and pjit.
...
- MaybeCollectGarbage
- The recursive check.
- DevicePut for np arrays and scalars when device_count == 1.
PiperOrigin-RevId: 507972281
2023-02-07 21:33:07 -08:00
jax authors
7783a5e129
Merge pull request #14343 from skye:cache_options_check
...
PiperOrigin-RevId: 507933881
jax-v0.4.3
jaxlib-v0.4.3
jax-v0.4.3-rc
2023-02-07 17:49:31 -08:00
Skye Wanderman-Milne
1228cbd26b
Change the executable_build_options check in compilation_cache.py to be more robust.
...
Prior to this change, the check would spuriously fire on Python 3.11
because it added a default `__getstate__` method to all objects. This
change makes it so we only look at public fields and methods.
2023-02-08 01:22:14 +00:00
jax authors
a46f31bca8
Merge pull request #14342 from skye:version
...
PiperOrigin-RevId: 507907767
2023-02-07 15:51:49 -08:00
Skye Wanderman-Milne
8ab158574d
Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release
2023-02-07 15:45:28 -08:00
Skye Wanderman-Milne
eb13c053e9
Add option to run tests with persistent compilation cache enabled.
...
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.
PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Peter Hawkins
6860cb8d2a
Move jax.interpreters.xla to jax._src.interpreters.xla.
...
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
jax authors
9c827fbd9a
Merge pull request #14340 from ROCmSoftwarePlatform:rocm_reenable_linalg_sparse_tests
...
PiperOrigin-RevId: 507886628
2023-02-07 14:30:37 -08:00
Rahul Batra
01a10a1d06
[ROCm] Re-enable some linalg and sparse tests
2023-02-07 22:05:14 +00:00