15503 Commits

Author SHA1 Message Date
Jake VanderPlas
8f72454bdf Add internal jax.lax.asarray utility 2023-03-30 10:21:55 -07:00
Peter Hawkins
67a28ce30f Relax test tolerances for testLogisticPpf.
Fixes a test failure in CI.

PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
jax authors
dedfc8df75 Merge pull request #15282 from JiaYaobo:geom_random
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
1fd6e01289 Merge pull request #15287 from gnecula:tf_dim_vars
PiperOrigin-RevId: 520633830
2023-03-30 07:37:47 -07:00
jax authors
794769c113 Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf Merge pull request #15297 from jakevdp:finfo-props
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
George Necula
081b86b82a [shape_poly] Improved computation of dimension variables for native serialization
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)

The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape if `2 *b`.

This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.

What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.

After we land this, we will refactor the shape environment to be cleaner.
2023-03-30 15:51:24 +02:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d improve pmap in_axes/out_axes pytree prefix error messages 2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
jax authors
6ce5eeb85f Merge pull request #15299 from jakevdp:fix-docs
PiperOrigin-RevId: 520439293
2023-03-29 14:14:41 -07:00
Jake VanderPlas
8562a8d7bc DOC: pin pydata-sphinx-theme to fix incompatibility 2023-03-29 13:55:32 -07:00
Yash Katariya
830cd9fd98 Delete _single_device_array_from_buf since everything from JAX is an Array
PiperOrigin-RevId: 520418231
2023-03-29 12:59:12 -07:00
Jake VanderPlas
5759bf05df jnp.finfo: add missing properties 2023-03-29 11:23:51 -07:00
Peter Hawkins
f48dbf039e Use cast instead of pytype suppression in setops.py.
pytype cannot tell from the type signature that unique() returns an Array, not a tuple. Add a cast to help it along.

It's possible that a future use of @overload on the definition of jnp.unique might help.

PiperOrigin-RevId: 520389675
2023-03-29 11:17:00 -07:00
jiayaobo
924894e85c add geometric random gen
add geom random

add geom random

add geom random

add geom random
2023-03-30 02:08:04 +08:00
Roy Frostig
c8a7d5990d fix custom_jvp check for tracers in arguments marked nondiff_argnums
PiperOrigin-RevId: 520379098
2023-03-29 10:42:06 -07:00
Jake VanderPlas
f282c251d4 Add minimal pyproject.toml specifying build system
Replaces #15274, Fixes #15256

PiperOrigin-RevId: 520367622
jax-v0.4.8
2023-03-29 10:08:30 -07:00
jax authors
cfa330b6fa Merge pull request #15283 from JiaYaobo:fix_wald_doc
PiperOrigin-RevId: 520364879
2023-03-29 10:01:27 -07:00
jax authors
2d94f76ca3 Merge pull request #15278 from hawkinsp:cudainstall
PiperOrigin-RevId: 520364354
2023-03-29 09:53:58 -07:00
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
jax authors
a964ae7fac Internal Code Change
PiperOrigin-RevId: 520341781
2023-03-29 08:23:56 -07:00
jax authors
7200d07db5 Merge pull request #15286 from hawkinsp:testjobs
PiperOrigin-RevId: 520319910
2023-03-29 06:39:04 -07:00
Peter Hawkins
d9b0f3cd6f Recommend --local_test_jobs in bazel test command line on GPU. 2023-03-29 09:28:53 -04:00
jax authors
07fc0222a1 Merge pull request #15279 from hawkinsp:versions
PiperOrigin-RevId: 520307157
2023-03-29 05:25:50 -07:00
jiayaobo
3a4d0b3552 remove scale in wald docstring 2023-03-29 11:39:49 +08:00
Peter Hawkins
705b5cc000 Add version constraints to CUDA pip wheel dependencies.
Fixes https://github.com/google/jax/issues/15267
2023-03-28 21:55:32 -04:00
Peter Hawkins
775f404f31 Update the CUDA installation instructions. 2023-03-28 21:46:07 -04:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
jax authors
8c4fed6410 Merge pull request #15270 from skye:pjrt_c_api
PiperOrigin-RevId: 520156646
2023-03-28 15:48:27 -07:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
Rebecca Chen
5ae2e79d43 Silence some pytype errors.
PiperOrigin-RevId: 520150523
2023-03-28 15:24:48 -07:00
jax authors
4061bbbbc2 Merge pull request #15269 from skye:min_jaxlib_version
PiperOrigin-RevId: 520127548
2023-03-28 14:02:27 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
jax authors
014033b75e Merge pull request #15266 from mehdiataei:mehdiataei-patch-1
PiperOrigin-RevId: 520111030
2023-03-28 13:09:56 -07:00
jax authors
bbec461c8b Merge pull request #15263 from jakevdp:deprecations
PiperOrigin-RevId: 520110559
2023-03-28 13:02:32 -07:00
Jake VanderPlas
fc47137ca8 Add deprecation warnings for several top-level jax imports 2023-03-28 12:40:59 -07:00
Yash Katariya
2f105bde2d Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning
PiperOrigin-RevId: 520098757
2023-03-28 12:17:30 -07:00
mehdiataei
8d090a01d0
Fixed spelling error in msgs 2023-03-28 14:48:18 -04:00
jax authors
2fbccc8f5e Merge pull request #15251 from jakevdp:mypy-deps
PiperOrigin-RevId: 520079899
2023-03-28 11:11:15 -07:00
Yash Katariya
7442faa715 Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding.
PiperOrigin-RevId: 520072687
2023-03-28 10:47:42 -07:00
Yash Katariya
97c8ce31ed Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default.
PiperOrigin-RevId: 520067392
2023-03-28 10:33:00 -07:00
Colin Gaffney
6ace6667dd Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.
PiperOrigin-RevId: 520064049
2023-03-28 10:25:07 -07:00
George Necula
2ac2dc65b1 Remove jax2tf experimental_native_lowering.
Users should use native_serialization.

PiperOrigin-RevId: 520063928
2023-03-28 10:17:58 -07:00
Yash Katariya
86c0b36bfd Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12
PiperOrigin-RevId: 520056811
2023-03-28 09:54:36 -07:00
jax authors
4106c35e0a Merge pull request #15258 from gnecula:dim_vars
PiperOrigin-RevId: 520033493
2023-03-28 08:21:00 -07:00
George Necula
a1538c73b4 [shape_poly] Refactor the computation of the dimension variables in native serialization
Currently, JAX native serialization produces a module whose main function
takes additional arguments for the values of the dimension variables. These
values are then resolved in the XlaCallModule based on a dim_args_spec
parameter.

We move the code that computes the dimension variables from XlaCallModule to
jax_export following pretty much the same technique. This simplifies
XlaCallModule and especially its API (the dim_args_spec). So far this
is just a refactoring with no semantic changes, but this will allow us
to improve the support for dimension variables that occur in linear
polynomials, e.g., "2*b" rather than just "b".
2023-03-28 12:51:48 +02:00
jax authors
4533578bba Merge pull request #15206 from jakevdp:expm-batch
PiperOrigin-RevId: 519883991
2023-03-27 18:21:27 -07:00
jax authors
03d1442c4f Merge pull request #15241 from jakevdp:instance-check
PiperOrigin-RevId: 519868384
2023-03-27 17:06:00 -07:00
jax authors
6f39237bc0 Merge pull request #15243 from jakevdp:coo-sort-warning
PiperOrigin-RevId: 519864866
2023-03-27 16:50:24 -07:00