15481 Commits

Author SHA1 Message Date
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
jax authors
a964ae7fac Internal Code Change
PiperOrigin-RevId: 520341781
2023-03-29 08:23:56 -07:00
jax authors
7200d07db5 Merge pull request #15286 from hawkinsp:testjobs
PiperOrigin-RevId: 520319910
2023-03-29 06:39:04 -07:00
Peter Hawkins
d9b0f3cd6f Recommend --local_test_jobs in bazel test command line on GPU. 2023-03-29 09:28:53 -04:00
jax authors
07fc0222a1 Merge pull request #15279 from hawkinsp:versions
PiperOrigin-RevId: 520307157
2023-03-29 05:25:50 -07:00
Peter Hawkins
705b5cc000 Add version constraints to CUDA pip wheel dependencies.
Fixes https://github.com/google/jax/issues/15267
2023-03-28 21:55:32 -04:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
jax authors
8c4fed6410 Merge pull request #15270 from skye:pjrt_c_api
PiperOrigin-RevId: 520156646
2023-03-28 15:48:27 -07:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
Rebecca Chen
5ae2e79d43 Silence some pytype errors.
PiperOrigin-RevId: 520150523
2023-03-28 15:24:48 -07:00
jax authors
4061bbbbc2 Merge pull request #15269 from skye:min_jaxlib_version
PiperOrigin-RevId: 520127548
2023-03-28 14:02:27 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
jax authors
014033b75e Merge pull request #15266 from mehdiataei:mehdiataei-patch-1
PiperOrigin-RevId: 520111030
2023-03-28 13:09:56 -07:00
jax authors
bbec461c8b Merge pull request #15263 from jakevdp:deprecations
PiperOrigin-RevId: 520110559
2023-03-28 13:02:32 -07:00
Jake VanderPlas
fc47137ca8 Add deprecation warnings for several top-level jax imports 2023-03-28 12:40:59 -07:00
Yash Katariya
2f105bde2d Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning
PiperOrigin-RevId: 520098757
2023-03-28 12:17:30 -07:00
mehdiataei
8d090a01d0
Fixed spelling error in msgs 2023-03-28 14:48:18 -04:00
jax authors
2fbccc8f5e Merge pull request #15251 from jakevdp:mypy-deps
PiperOrigin-RevId: 520079899
2023-03-28 11:11:15 -07:00
Yash Katariya
7442faa715 Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding.
PiperOrigin-RevId: 520072687
2023-03-28 10:47:42 -07:00
Yash Katariya
97c8ce31ed Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default.
PiperOrigin-RevId: 520067392
2023-03-28 10:33:00 -07:00
Colin Gaffney
6ace6667dd Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.
PiperOrigin-RevId: 520064049
2023-03-28 10:25:07 -07:00
George Necula
2ac2dc65b1 Remove jax2tf experimental_native_lowering.
Users should use native_serialization.

PiperOrigin-RevId: 520063928
2023-03-28 10:17:58 -07:00
Yash Katariya
86c0b36bfd Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12
PiperOrigin-RevId: 520056811
2023-03-28 09:54:36 -07:00
jax authors
4106c35e0a Merge pull request #15258 from gnecula:dim_vars
PiperOrigin-RevId: 520033493
2023-03-28 08:21:00 -07:00
George Necula
a1538c73b4 [shape_poly] Refactor the computation of the dimension variables in native serialization
Currently, JAX native serialization produces a module whose main function
takes additional arguments for the values of the dimension variables. These
values are then resolved in the XlaCallModule based on a dim_args_spec
parameter.

We move the code that computes the dimension variables from XlaCallModule to
jax_export following pretty much the same technique. This simplifies
XlaCallModule and especially its API (the dim_args_spec). So far this
is just a refactoring with no semantic changes, but this will allow us
to improve the support for dimension variables that occur in linear
polynomials, e.g., "2*b" rather than just "b".
2023-03-28 12:51:48 +02:00
jax authors
4533578bba Merge pull request #15206 from jakevdp:expm-batch
PiperOrigin-RevId: 519883991
2023-03-27 18:21:27 -07:00
jax authors
03d1442c4f Merge pull request #15241 from jakevdp:instance-check
PiperOrigin-RevId: 519868384
2023-03-27 17:06:00 -07:00
jax authors
6f39237bc0 Merge pull request #15243 from jakevdp:coo-sort-warning
PiperOrigin-RevId: 519864866
2023-03-27 16:50:24 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Jake VanderPlas
61190cbcb0 CI: add numpy & scipy to mypy env 2023-03-27 15:08:44 -07:00
Yash Katariya
670fba3a91 Finish jax and jaxlib 0.4.7 release
PiperOrigin-RevId: 519839723
2023-03-27 15:06:38 -07:00
Sharad Vikram
10dc941d8d Add jaxlib version guard for rnn test
PiperOrigin-RevId: 519833650
2023-03-27 14:43:46 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
ae4f1fcb66 Update the commit in workspace too
PiperOrigin-RevId: 519797427
jax-v0.4.7 jaxlib-v0.4.7 jax-v0.4.7-rc1
2023-03-27 12:30:18 -07:00
Yash Katariya
e9cac5eb47 Prepare for jax and jaxlib 0.4.7 release
PiperOrigin-RevId: 519785176
jax-v0.4.7-rc
2023-03-27 11:45:22 -07:00
Yash Katariya
e21aee18a8 Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Sharad Vikram
3c3fa042e3 Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Jake VanderPlas
392bd93e4d [sparse] fix coo efficiency warning 2023-03-27 10:15:43 -07:00
Jake VanderPlas
ed9fa1342b jax.typing: recommend instance check in Python 3.10 or newer 2023-03-27 10:01:28 -07:00
Peter Hawkins
40fb646e35 Fix duplicate definition of 'cuda' extra in setup.py.
PiperOrigin-RevId: 519750659
2023-03-27 09:52:37 -07:00
jax authors
af4d4943a7 Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type
PiperOrigin-RevId: 519745476
2023-03-27 09:52:20 -07:00
jax authors
10d51c78f6 Merge pull request #15218 from hawkinsp:mypy
PiperOrigin-RevId: 519745465
2023-03-27 09:37:54 -07:00
Yash Katariya
41695cc78c Temporarily fix the compilation cache test which is failing on latest jaxlib release
PiperOrigin-RevId: 519745099
2023-03-27 09:37:37 -07:00
jax authors
2c4be6f662 Merge pull request #15226 from canyon289:patch-1
PiperOrigin-RevId: 519743393
2023-03-27 09:30:24 -07:00
jax authors
d19e60ea07 Merge pull request #15228 from canyon289:patch-2
PiperOrigin-RevId: 519742908
2023-03-27 09:30:07 -07:00
Yash Katariya
cf8c2b8450 Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py
PiperOrigin-RevId: 519742866
2023-03-27 09:22:57 -07:00
jax authors
d473e86912 Merge pull request #13008 from hawkinsp:pipcuda
PiperOrigin-RevId: 519740461
2023-03-27 09:14:24 -07:00
jax authors
6715736583 Merge pull request #15205 from yhtang:editable-jaxlib-build
PiperOrigin-RevId: 519704474
2023-03-27 06:33:31 -07:00
jax authors
f3613a11b9 Merge pull request #15234 from gnecula:get_dim_size
PiperOrigin-RevId: 519691037
2023-03-27 05:21:54 -07:00