Peter Hawkins
6ee67639e2
Split PyTorch interoperability tests into their own test.
...
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00
jax authors
5da5967d08
Merge pull request #14395 from jakevdp:bcsr-dot-general
...
PiperOrigin-RevId: 508721790
2023-02-10 12:09:29 -08:00
Jake VanderPlas
ac647b9459
[sparse] implement autodiff rules for bcsr_dot_general
2023-02-10 12:00:30 -08:00
jax authors
7a864d73bc
Merge pull request #14394 from jakevdp:jax-array-methods
...
PiperOrigin-RevId: 508694486
2023-02-10 10:27:14 -08:00
George Necula
be21404085
[jax2tf] Add shard_map tests
...
Also fix tests to run on multiple devices in TF
PiperOrigin-RevId: 508691872
2023-02-10 10:18:19 -08:00
jax authors
d09f3c2ee4
Merge pull request #11727 from gnecula:call_tf_checks
...
PiperOrigin-RevId: 508685246
2023-02-10 09:51:35 -08:00
Jake VanderPlas
60256df668
[typing] define additional methods & properties on jax.Array
...
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
2023-02-10 09:42:32 -08:00
John QiangZhang
7659a3a271
Enable call_tf_native_lowering_test.
...
PiperOrigin-RevId: 508677359
2023-02-10 09:16:53 -08:00
jax authors
9f0783f35d
Merge pull request #14403 from gnecula:reduce_precision
...
PiperOrigin-RevId: 508635187
2023-02-10 05:38:59 -08:00
jax authors
f070557260
Merge pull request #14400 from gnecula:native_bug1
...
PiperOrigin-RevId: 508635169
2023-02-10 05:30:24 -08:00
George Necula
30fda87142
[call_tf] Improve error reporting
...
Add more checks to catch early the cases when the called TF function
returns values that are not convertible to JAX values (arrays of
numeric values). All these cases were resulting in errors even before
but sometimes these errors were deep in the stack and harder to
diagnose.
2023-02-10 14:19:49 +01:00
George Necula
48c2538365
[jax2tf] Add support for reduce_precision
2023-02-10 13:29:46 +01:00
George Necula
ff6051fc31
[shape_poly] Better error message for functions that do not use input arguments
...
Also:
* fixed some of the tests that were using the shape but not the value
of the input arguments
* fix importing of mlir.py due to recent move of interpreters.mlir to
_src.interpreters.mlir
2023-02-10 10:59:46 +01:00
Peter Hawkins
54ff78dbde
Deprecate jax.interpreters.xla.Device and jax.interpreters.xla.DeviceArray.
...
PiperOrigin-RevId: 508502470
2023-02-09 16:11:48 -08:00
jax authors
357b48d29a
Merge pull request #14391 from ROCmSoftwarePlatform:rocm_switch_to_rocm54
...
PiperOrigin-RevId: 508497281
2023-02-09 15:50:49 -08:00
Roy Frostig
1c84e4a753
migrate internal dependencies from jax.interpreters.batching
to jax._src.interpreters.batching
...
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.
PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
jax authors
12dc73dc6e
Merge pull request #14388 from jakevdp:bcsr-todense-ad
...
PiperOrigin-RevId: 508477843
2023-02-09 14:41:41 -08:00
Jieying Luo
668b82d529
[PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
...
Loading TPU PJRT plugin is moved to make_tpu_client.
This change is based on https://github.com/google/jax/pull/14011 .
PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Jake VanderPlas
7651866b1d
[sparse] implement autodiff rules for bcsr primitives
2023-02-09 14:19:22 -08:00
Jake VanderPlas
15c9bca67f
[sparse] add cusparse lowering for simplest cases of bcsr_dot_general
...
PiperOrigin-RevId: 508473938
2023-02-09 14:18:44 -08:00
jax authors
253cd4d9d1
Merge pull request #14387 from ROCmSoftwarePlatform:rocm_reenable_dirichlet_test
...
PiperOrigin-RevId: 508466026
2023-02-09 13:50:06 -08:00
Peter Hawkins
88cc254f2c
[JAX] Replace uses of jax.interpreters.pxla.ShardedDeviceArray with jax.Array.
...
PiperOrigin-RevId: 508463147
2023-02-09 13:39:41 -08:00
Peter Hawkins
0c14e9ab49
Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
...
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.
Fix some overly-strict type annotations.
PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
jax authors
adcceb228f
Merge pull request #14384 from mattjj:pjit-pretty-print
...
PiperOrigin-RevId: 508454299
2023-02-09 13:04:58 -08:00
Matthew Johnson
a964dc3b9a
simpler pretty-print for pjit, tweak custom pp rule signature
2023-02-09 12:45:51 -08:00
Rahul Batra
7d0d9b706e
[ROCm]: Re-enable Dirichlet Tests on ROCm
2023-02-09 20:19:07 +00:00
Rahul Batra
023226e181
[ROCm]: Move dockerfile to ROCm5.4
2023-02-09 20:08:35 +00:00
Peter Hawkins
8268cd562d
Add infrastructure for managing deprecations.
...
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.
PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
jax authors
3f8cb0a7c9
Merge pull request #14379 from mattjj:shmap-vmap-spmd-axis-name
...
PiperOrigin-RevId: 508292029
2023-02-09 00:14:30 -08:00
Matthew Johnson
6fb3ace5d0
[shard-map] add vmap spmd_axis_name support, fix vmap rule bug
2023-02-08 23:54:28 -08:00
jax authors
bd7c227e96
Merge pull request #14373 from mattjj:shmap-check-rep-false
...
PiperOrigin-RevId: 508219490
2023-02-08 16:49:29 -08:00
Matthew Johnson
1a03f34383
[shard-map] if check_rep=False, don't call rep rules in eager
2023-02-08 15:42:35 -08:00
jax authors
ccb974a150
Merge pull request #14370 from jakevdp:argpartition-impl
...
PiperOrigin-RevId: 508194466
2023-02-08 15:10:50 -08:00
Peter Hawkins
a28b01243b
Move contents of jax.monitoring to jax._src.monitoring.
...
PiperOrigin-RevId: 508191560
2023-02-08 15:03:22 -08:00
Yash Katariya
7350f00acd
Remove jax_experimental_subjaxpr_lowering_cache
since it was only for jit
and was False
by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
...
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
4fbaee5920
Implement jax.numpy.argpartition
2023-02-08 14:41:39 -08:00
Peter Hawkins
cc8d7fae32
Move jax.interpreters.mlir to jax._src.interpreters.mlir.
...
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.
PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
jax authors
3e349c7bed
Merge pull request #14361 from jakevdp:doc-topk
...
PiperOrigin-RevId: 508181335
2023-02-08 14:19:01 -08:00
Yash Katariya
e4d551a217
Remove the doctest skip now that jit and pjit have been merged
...
PiperOrigin-RevId: 508162840
2023-02-08 13:09:53 -08:00
jax authors
1254d44dbd
Remove silent data corruption runtime flags from persistent cache key.
...
These flags have no effect on the compiled executable, just the runtime execution.
PiperOrigin-RevId: 508152877
2023-02-08 12:31:27 -08:00
Ashish Shenoy
f71a55c554
Rename tensorflow core target variable to tensorflow_core
...
PiperOrigin-RevId: 508148106
2023-02-08 12:11:59 -08:00
Yash Katariya
6ec9082cf5
Default jax_jit_pjit_api_merge
to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh
context manager.
...
This changes the internals of JAX without affecting any public API.
Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing ).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
jax authors
9a1f9b1ef8
Merge pull request #14362 from mattjj:shmap-remat
...
PiperOrigin-RevId: 508139783
2023-02-08 11:42:12 -08:00
Matthew Johnson
58d3f552d7
[shard-map] add remat support, very basic test
2023-02-08 11:15:38 -08:00
Jake VanderPlas
3c6183498a
lax.top_k: improve documentation and errors on invalid values
2023-02-08 11:07:56 -08:00
jax authors
4844e3f85c
Merge pull request #14357 from skye:version
...
PiperOrigin-RevId: 508119911
2023-02-08 10:42:48 -08:00
jax authors
4358b803e9
Merge pull request #14355 from jakevdp:tril-indices
...
PiperOrigin-RevId: 508119785
2023-02-08 10:35:10 -08:00
Yash Katariya
7b1128fdc4
Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
...
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Skye Wanderman-Milne
21f12183bf
Post 0.4.3 release updates
2023-02-08 10:08:59 -08:00
Jake VanderPlas
a76a024548
tril/triu_indices: compute arrays at runtime
2023-02-08 09:52:41 -08:00