18680 Commits

Author SHA1 Message Date
Parker Schuh
7ba8622719 For custom_partitioning, directly emit call when inside of a shard_map.
PiperOrigin-RevId: 592011427
2023-12-18 14:32:38 -08:00
jax authors
afdb7370b9 Merge pull request #19032 from jakevdp:upload-artifact
PiperOrigin-RevId: 591996988
2023-12-18 13:38:50 -08:00
George Necula
cc2a3eb564 Move export backwards compatibility tests out of jax2tf. Step 2.
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.

This is part of a larger cl/562671314 that ran into OSS build problems.

This is step 2: moves the other test data Python files.

PiperOrigin-RevId: 591934999
2023-12-18 10:16:54 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
George Necula
eed61f68aa Move export backwards compatibility tests out of jax2tf. Step 1.
These tests are independent of TensorFlow, yet by being in the jax2tf package they end up pulling in TensorFlow as a dependency.

This is part of a larger cl/562671314 that ran into OSS build problems.
I am attempting this smaller change first, and afterwards I will move more of the test data files, and then the actual test.

PiperOrigin-RevId: 591927484
2023-12-18 09:49:52 -08:00
Jake VanderPlas
873b33d7a5 CI: bump actions/upload-artifact from 3.1.3 to 4.0.0 2023-12-18 09:36:01 -08:00
jax authors
36cf5afa67 Merge pull request #19016 from gnecula:exp_fix_float0
PiperOrigin-RevId: 591869394
2023-12-18 05:42:17 -08:00
George Necula
7aba11f87f [export] Fix handling of float0 when exporting
There were two problems:
  * the float0 dtype was not part of the schema,
  * there was a bug invoking jax.vjp on a reloaded
    function, because of a mismatch between the type
    of symbolic zeros.

We changed the schema to add `f0`, but we add that
enum with a value larger than existing values, to
preserve backwards compatibility.
2023-12-18 14:51:27 +02:00
jax authors
259c285b10 [Jax] Enable jax_include_full_tracebacks_in_locations by default
PiperOrigin-RevId: 591783126
2023-12-17 21:56:13 -08:00
jax authors
c8963927e2 Update XLA dependency to use revision
fa9331a7e5.

PiperOrigin-RevId: 591770777
2023-12-17 20:28:23 -08:00
jax authors
e0cc9879d5 Merge pull request #19017 from gnecula:export_shard_map
PiperOrigin-RevId: 591614088
2023-12-16 22:26:37 -08:00
George Necula
d1e5199928 [export] Add test for exporting jit(shard_map) 2023-12-17 08:21:03 +02:00
jax authors
7af728541e Update XLA dependency to use revision
ac0cda8246.

PiperOrigin-RevId: 591600480
2023-12-16 20:26:31 -08:00
jax authors
6d876ed643 Update XLA dependency to use revision
dc2b009cc6.

PiperOrigin-RevId: 591430228
2023-12-15 20:51:41 -08:00
jax authors
732fd15be9 Merge pull request #19012 from skye:test_no_log_spam
PiperOrigin-RevId: 591400099
2023-12-15 17:23:04 -08:00
Skye Wanderman-Milne
904103eee2 Set TF_CPP_MIN_LOG_LEVEL=1 for new log spam test. 2023-12-15 17:18:46 -08:00
Jieying Luo
717059edcf Bump oldest support libtpu version to 20230927 which includes CopyToMemorySpace.
batch_device_put in libtpu will go through memory space path, and requires CopyToMemorySpace to be implemented if the backend uses memory space.

PiperOrigin-RevId: 591373300
2023-12-15 15:19:06 -08:00
jax authors
bc239fd1b3 Merge pull request #19005 from jakevdp:array-api-linalg
PiperOrigin-RevId: 591364100
2023-12-15 14:47:11 -08:00
jax authors
65d2f362ed Merge pull request #19006 from skye:test_no_log_spam
PiperOrigin-RevId: 591362840
2023-12-15 14:39:32 -08:00
Skye Wanderman-Milne
e5f7598166 Add unit test for catching log spam. 2023-12-15 14:29:40 -08:00
Jake VanderPlas
0c7b959dac jnp.linalg: add matrix_norm, matrix_transpose, vector_norm, vector_transpose
These have been added upstream to numpy.linalg in NumPy 2.0, as part of the Array API standard.
2023-12-15 14:17:36 -08:00
jax authors
64799a431a Merge pull request #19002 from jakevdp:core-dep
PiperOrigin-RevId: 591352495
2023-12-15 13:59:52 -08:00
Jake VanderPlas
adefbca642 jax.core: deprecate several private APIs 2023-12-15 13:37:09 -08:00
jax authors
9462aec52b Merge pull request #19000 from jakevdp:array-api-update
PiperOrigin-RevId: 591302038
2023-12-15 10:53:34 -08:00
Jake VanderPlas
ad67726c6e array-api: update CI to use latest array-api-tests commit 2023-12-15 10:49:56 -08:00
jax authors
eb595af36f Merge pull request #18989 from hawkinsp:fork
PiperOrigin-RevId: 591286756
2023-12-15 10:01:36 -08:00
Jieying Luo
c8b3567e82 Add two flags to support only building cuda kernel plugin or cuda pjrt plugin.
PiperOrigin-RevId: 591274120
2023-12-15 09:15:46 -08:00
Peter Hawkins
ec89e5e4c5 Add a warning if the user calls os.fork().
Fixes https://github.com/google/jax/issues/18852
2023-12-15 09:29:55 -05:00
Sergei Lebedev
41531123f4 Rolling back #18980, because it is not backwards compatible and breaks existing users.
Reverts 91faddd023c2df77df310f3f2f17eb2fa1e60df0

PiperOrigin-RevId: 591200403
2023-12-15 03:24:01 -08:00
Goran Flegar
a0458e6a1c Import new version of Triton
PiperOrigin-RevId: 591195548
2023-12-15 02:57:27 -08:00
jax authors
91faddd023 Merge pull request #18980 from gnecula:export_api
PiperOrigin-RevId: 591172917
2023-12-15 00:52:19 -08:00
George Necula
fd0f007765 [export] Refactor the imports for the public API of jax.experimental.export
Previously we used `from jax.experimental.export import export` and
`export.export(fun)`. Now we want to add the public API directly
to `jax.experimental.export`, for the following desired usage:

```
from jax.experimental import export

exp: export.Exported = export.export(fun)
ser: bytearray = export.serialize(exp)
exp1 = export.deserialized(ser)
export.call(exp1)
```

This change also includes a workaround to allow users to still
do `from jax.experimental.export import export`, for a while.
2023-12-15 10:00:05 +02:00
jax authors
14ce5f70c2 Update XLA dependency to use revision
56dca1628e.

PiperOrigin-RevId: 591140795
2023-12-14 21:39:31 -08:00
jax authors
6cd7adac99 Merge pull request #18956 from gnecula:export_effects
PiperOrigin-RevId: 591134940
2023-12-14 21:03:11 -08:00
George Necula
552010a381 [export] Fix the serialization of effects
We currently support only the serialization of effects with
nullary constructors. We must also ensure that upon deserialization
we produce an event that tests equal to the original one.
Here we add explicit error checks and tests.

We also make the CallTfEffect to have this property.
2023-12-15 06:52:45 +02:00
jax authors
a7b60234d9 Merge pull request #18991 from skye:revert_cuda_install
PiperOrigin-RevId: 591097432
2023-12-14 17:28:34 -08:00
Skye Wanderman-Milne
5d26c307ce Revert "Recommend the plugin in the CUDA installation instructions."
This reverts commit 78bc4f5ced41f7f96a70da769e7c3170dc2b3161.

GPU dlpack is broken with the new plugin. Recommend working GPU jaxlib until we can fix.
2023-12-14 17:23:18 -08:00
jax authors
891d44ccf5 Merge pull request #18983 from mattjj:eager-pmap-custom-jvp-vmap
PiperOrigin-RevId: 591085426
2023-12-14 16:36:38 -08:00
Matthew Johnson
b70ac9047d fix a bug with eager pmap + vmap + custom_jvp interaction
I used the same implementation technique in shard_map.py, e.g. in ShardMapTrace.process_custom_jvp_call, and it's sound, whereas I can't remember why we implementd the eager pmap stuff the way we did.

This fixes an internal test, but unfortunately I wasn't able to figure out a simple repro :/
2023-12-14 16:16:55 -08:00
jax authors
c4a9cee78a Merge pull request #18987 from hawkinsp:plugin
PiperOrigin-RevId: 591063011
2023-12-14 15:07:19 -08:00
Jieying Luo
1559d6495e Remove local version in jax-cuda-plugin and jax-cuda-pjrt package.
PiperOrigin-RevId: 591057013
2023-12-14 14:44:49 -08:00
Peter Hawkins
78bc4f5ced Recommend the plugin in the CUDA installation instructions. 2023-12-14 17:21:30 -05:00
Peter Hawkins
cf5a49584d Remove XLA logging patch, now the XLA fix has landed.
PiperOrigin-RevId: 590959603
2023-12-14 09:22:40 -08:00
Yash Katariya
8bf3a86860 [roll forward 2] Remove the `jax_require_devices_during_lowering flag since it was temporary. Added the semi-breaking change to Changelog.md.
Reverts b52bcc1639368069075284eefc763f824ca155f1

PiperOrigin-RevId: 590959383
2023-12-14 09:14:25 -08:00
Yash Katariya
6e1ab7ca3f Finish release of jax and jaxlib 0.4.23
PiperOrigin-RevId: 590833947
2023-12-13 23:39:08 -08:00
jax authors
296a7135f2 Update XLA dependency to use revision
352fc1e624.

PiperOrigin-RevId: 590818742
2023-12-13 22:19:42 -08:00
Yash Katariya
1c68577dcd Prepare for emergency jax and jaxlib 0.4.23 release
PiperOrigin-RevId: 590780824
jax-v0.4.23 jax-v0.4.23-rc
2023-12-13 19:02:24 -08:00
Peter Hawkins
b392622647 Add patch to suppress XLA:GPU logging.
PiperOrigin-RevId: 590780227
2023-12-13 18:53:50 -08:00
Yash Katariya
25c16c0b78 Finish jax and jaxlib 0.4.22 release
PiperOrigin-RevId: 590775311
2023-12-13 18:26:47 -08:00
Roy Frostig
3380b9feee split the random generalized normal test and skip its K-S half
It is key-sensitive and sometimes slow.

PiperOrigin-RevId: 590756597
2023-12-13 17:01:19 -08:00