14753 Commits

Author SHA1 Message Date
Peter Hawkins
6ee67639e2 Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
jax authors
5da5967d08 Merge pull request #14395 from jakevdp:bcsr-dot-general
PiperOrigin-RevId: 508721790
2023-02-10 12:09:29 -08:00
Jake VanderPlas
ac647b9459 [sparse] implement autodiff rules for bcsr_dot_general 2023-02-10 12:00:30 -08:00
jax authors
7a864d73bc Merge pull request #14394 from jakevdp:jax-array-methods
PiperOrigin-RevId: 508694486
2023-02-10 10:27:14 -08:00
George Necula
be21404085 [jax2tf] Add shard_map tests
Also fix tests to run on multiple devices in TF

PiperOrigin-RevId: 508691872
2023-02-10 10:18:19 -08:00
jax authors
d09f3c2ee4 Merge pull request #11727 from gnecula:call_tf_checks
PiperOrigin-RevId: 508685246
2023-02-10 09:51:35 -08:00
Jake VanderPlas
60256df668 [typing] define additional methods & properties on jax.Array
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
2023-02-10 09:42:32 -08:00
John QiangZhang
7659a3a271 Enable call_tf_native_lowering_test.
PiperOrigin-RevId: 508677359
2023-02-10 09:16:53 -08:00
jax authors
9f0783f35d Merge pull request #14403 from gnecula:reduce_precision
PiperOrigin-RevId: 508635187
2023-02-10 05:38:59 -08:00
jax authors
f070557260 Merge pull request #14400 from gnecula:native_bug1
PiperOrigin-RevId: 508635169
2023-02-10 05:30:24 -08:00
George Necula
30fda87142 [call_tf] Improve error reporting
Add more checks to catch early the cases when the called TF function
returns values that are not convertible to JAX values (arrays of
numeric values). All these cases were resulting in errors even before
but sometimes these errors were deep in the stack and harder to
diagnose.
2023-02-10 14:19:49 +01:00
George Necula
48c2538365 [jax2tf] Add support for reduce_precision 2023-02-10 13:29:46 +01:00
George Necula
ff6051fc31 [shape_poly] Better error message for functions that do not use input arguments
Also:
  * fixed some of the tests that were using the shape but not the value
  of the input arguments
  * fix importing of mlir.py due to recent move of interpreters.mlir to
  _src.interpreters.mlir
2023-02-10 10:59:46 +01:00
Peter Hawkins
54ff78dbde Deprecate jax.interpreters.xla.Device and jax.interpreters.xla.DeviceArray.
PiperOrigin-RevId: 508502470
2023-02-09 16:11:48 -08:00
jax authors
357b48d29a Merge pull request #14391 from ROCmSoftwarePlatform:rocm_switch_to_rocm54
PiperOrigin-RevId: 508497281
2023-02-09 15:50:49 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
jax authors
12dc73dc6e Merge pull request #14388 from jakevdp:bcsr-todense-ad
PiperOrigin-RevId: 508477843
2023-02-09 14:41:41 -08:00
Jieying Luo
668b82d529 [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
7651866b1d [sparse] implement autodiff rules for bcsr primitives 2023-02-09 14:19:22 -08:00
Jake VanderPlas
15c9bca67f [sparse] add cusparse lowering for simplest cases of bcsr_dot_general
PiperOrigin-RevId: 508473938
2023-02-09 14:18:44 -08:00
jax authors
253cd4d9d1 Merge pull request #14387 from ROCmSoftwarePlatform:rocm_reenable_dirichlet_test
PiperOrigin-RevId: 508466026
2023-02-09 13:50:06 -08:00
Peter Hawkins
88cc254f2c [JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
0c14e9ab49 Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.

Fix some overly-strict type annotations.

PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
jax authors
adcceb228f Merge pull request #14384 from mattjj:pjit-pretty-print
PiperOrigin-RevId: 508454299
2023-02-09 13:04:58 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Rahul Batra
7d0d9b706e [ROCm]: Re-enable Dirichlet Tests on ROCm 2023-02-09 20:19:07 +00:00
Rahul Batra
023226e181 [ROCm]: Move dockerfile to ROCm5.4 2023-02-09 20:08:35 +00:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
jax authors
3f8cb0a7c9 Merge pull request #14379 from mattjj:shmap-vmap-spmd-axis-name
PiperOrigin-RevId: 508292029
2023-02-09 00:14:30 -08:00
Matthew Johnson
6fb3ace5d0 [shard-map] add vmap spmd_axis_name support, fix vmap rule bug 2023-02-08 23:54:28 -08:00
jax authors
bd7c227e96 Merge pull request #14373 from mattjj:shmap-check-rep-false
PiperOrigin-RevId: 508219490
2023-02-08 16:49:29 -08:00
Matthew Johnson
1a03f34383 [shard-map] if check_rep=False, don't call rep rules in eager 2023-02-08 15:42:35 -08:00
jax authors
ccb974a150 Merge pull request #14370 from jakevdp:argpartition-impl
PiperOrigin-RevId: 508194466
2023-02-08 15:10:50 -08:00
Peter Hawkins
a28b01243b Move contents of jax.monitoring to jax._src.monitoring.
PiperOrigin-RevId: 508191560
2023-02-08 15:03:22 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
4fbaee5920 Implement jax.numpy.argpartition 2023-02-08 14:41:39 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
jax authors
3e349c7bed Merge pull request #14361 from jakevdp:doc-topk
PiperOrigin-RevId: 508181335
2023-02-08 14:19:01 -08:00
Yash Katariya
e4d551a217 Remove the doctest skip now that jit and pjit have been merged
PiperOrigin-RevId: 508162840
2023-02-08 13:09:53 -08:00
jax authors
1254d44dbd Remove silent data corruption runtime flags from persistent cache key.
These flags have no effect on the compiled executable, just the runtime execution.

PiperOrigin-RevId: 508152877
2023-02-08 12:31:27 -08:00
Ashish Shenoy
f71a55c554 Rename tensorflow core target variable to tensorflow_core
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
jax authors
9a1f9b1ef8 Merge pull request #14362 from mattjj:shmap-remat
PiperOrigin-RevId: 508139783
2023-02-08 11:42:12 -08:00
Matthew Johnson
58d3f552d7 [shard-map] add remat support, very basic test 2023-02-08 11:15:38 -08:00
Jake VanderPlas
3c6183498a lax.top_k: improve documentation and errors on invalid values 2023-02-08 11:07:56 -08:00
jax authors
4844e3f85c Merge pull request #14357 from skye:version
PiperOrigin-RevId: 508119911
2023-02-08 10:42:48 -08:00
jax authors
4358b803e9 Merge pull request #14355 from jakevdp:tril-indices
PiperOrigin-RevId: 508119785
2023-02-08 10:35:10 -08:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Skye Wanderman-Milne
21f12183bf Post 0.4.3 release updates 2023-02-08 10:08:59 -08:00
Jake VanderPlas
a76a024548 tril/triu_indices: compute arrays at runtime 2023-02-08 09:52:41 -08:00