9996 Commits

Author SHA1 Message Date
jax authors
dedfc8df75 Merge pull request #15282 from JiaYaobo:geom_random
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
1fd6e01289 Merge pull request #15287 from gnecula:tf_dim_vars
PiperOrigin-RevId: 520633830
2023-03-30 07:37:47 -07:00
jax authors
794769c113 Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf Merge pull request #15297 from jakevdp:finfo-props
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
George Necula
081b86b82a [shape_poly] Improved computation of dimension variables for native serialization
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)

The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape if `2 *b`.

This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.

What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.

After we land this, we will refactor the shape environment to be cleaner.
2023-03-30 15:51:24 +02:00
Peter Hawkins
47177e1417 Split more targets out the main JAX Bazel target.
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d improve pmap in_axes/out_axes pytree prefix error messages 2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
Yash Katariya
830cd9fd98 Delete _single_device_array_from_buf since everything from JAX is an Array
PiperOrigin-RevId: 520418231
2023-03-29 12:59:12 -07:00
Jake VanderPlas
5759bf05df jnp.finfo: add missing properties 2023-03-29 11:23:51 -07:00
Peter Hawkins
f48dbf039e Use cast instead of pytype suppression in setops.py.
pytype cannot tell from the type signature that unique() returns an Array, not a tuple. Add a cast to help it along.

It's possible that a future use of @overload on the definition of jnp.unique might help.

PiperOrigin-RevId: 520389675
2023-03-29 11:17:00 -07:00
jiayaobo
924894e85c add geometric random gen
add geom random

add geom random

add geom random

add geom random
2023-03-30 02:08:04 +08:00
Roy Frostig
c8a7d5990d fix custom_jvp check for tracers in arguments marked nondiff_argnums
PiperOrigin-RevId: 520379098
2023-03-29 10:42:06 -07:00
jax authors
cfa330b6fa Merge pull request #15283 from JiaYaobo:fix_wald_doc
PiperOrigin-RevId: 520364879
2023-03-29 10:01:27 -07:00
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
jax authors
a964ae7fac Internal Code Change
PiperOrigin-RevId: 520341781
2023-03-29 08:23:56 -07:00
jiayaobo
3a4d0b3552 remove scale in wald docstring 2023-03-29 11:39:49 +08:00
Peter Hawkins
c2d6fcc0e6 Split core.py and several files in an SCC with it into a separate Bazel build target.
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
jax authors
8c4fed6410 Merge pull request #15270 from skye:pjrt_c_api
PiperOrigin-RevId: 520156646
2023-03-28 15:48:27 -07:00
Skye Wanderman-Milne
473d1c3685 Turn on PJRT C API by default.
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
Rebecca Chen
5ae2e79d43 Silence some pytype errors.
PiperOrigin-RevId: 520150523
2023-03-28 15:24:48 -07:00
jax authors
4061bbbbc2 Merge pull request #15269 from skye:min_jaxlib_version
PiperOrigin-RevId: 520127548
2023-03-28 14:02:27 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
jax authors
014033b75e Merge pull request #15266 from mehdiataei:mehdiataei-patch-1
PiperOrigin-RevId: 520111030
2023-03-28 13:09:56 -07:00
jax authors
bbec461c8b Merge pull request #15263 from jakevdp:deprecations
PiperOrigin-RevId: 520110559
2023-03-28 13:02:32 -07:00
Jake VanderPlas
fc47137ca8 Add deprecation warnings for several top-level jax imports 2023-03-28 12:40:59 -07:00
mehdiataei
8d090a01d0
Fixed spelling error in msgs 2023-03-28 14:48:18 -04:00
Yash Katariya
7442faa715 Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding.
PiperOrigin-RevId: 520072687
2023-03-28 10:47:42 -07:00
Yash Katariya
97c8ce31ed Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default.
PiperOrigin-RevId: 520067392
2023-03-28 10:33:00 -07:00
Colin Gaffney
6ace6667dd Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.
PiperOrigin-RevId: 520064049
2023-03-28 10:25:07 -07:00
George Necula
2ac2dc65b1 Remove jax2tf experimental_native_lowering.
Users should use native_serialization.

PiperOrigin-RevId: 520063928
2023-03-28 10:17:58 -07:00
George Necula
a1538c73b4 [shape_poly] Refactor the computation of the dimension variables in native serialization
Currently, JAX native serialization produces a module whose main function
takes additional arguments for the values of the dimension variables. These
values are then resolved in the XlaCallModule based on a dim_args_spec
parameter.

We move the code that computes the dimension variables from XlaCallModule to
jax_export following pretty much the same technique. This simplifies
XlaCallModule and especially its API (the dim_args_spec). So far this
is just a refactoring with no semantic changes, but this will allow us
to improve the support for dimension variables that occur in linear
polynomials, e.g., "2*b" rather than just "b".
2023-03-28 12:51:48 +02:00
jax authors
4533578bba Merge pull request #15206 from jakevdp:expm-batch
PiperOrigin-RevId: 519883991
2023-03-27 18:21:27 -07:00
jax authors
03d1442c4f Merge pull request #15241 from jakevdp:instance-check
PiperOrigin-RevId: 519868384
2023-03-27 17:06:00 -07:00
jax authors
6f39237bc0 Merge pull request #15243 from jakevdp:coo-sort-warning
PiperOrigin-RevId: 519864866
2023-03-27 16:50:24 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Yash Katariya
670fba3a91 Finish jax and jaxlib 0.4.7 release
PiperOrigin-RevId: 519839723
2023-03-27 15:06:38 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Yash Katariya
e21aee18a8 Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Sharad Vikram
3c3fa042e3 Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Peter Hawkins
88c2898e36 Use pytype_strict_library() in Bazel build rules.
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Jake VanderPlas
392bd93e4d [sparse] fix coo efficiency warning 2023-03-27 10:15:43 -07:00
Jake VanderPlas
ed9fa1342b jax.typing: recommend instance check in Python 3.10 or newer 2023-03-27 10:01:28 -07:00
jax authors
d473e86912 Merge pull request #13008 from hawkinsp:pipcuda
PiperOrigin-RevId: 519740461
2023-03-27 09:14:24 -07:00
George Necula
befb449f05 [shape_poly] Fixed bug with dimension variables in unused args
JAX will aggressively drop module input arguments if they are not
used. This can interfere with shape polymorphism, because it may
result in dropping arguments from which we need to derive the
values of shape variables.

We fix this for now by disabling dropping arguments if there
are dimension variables in the arguments shapes. A more precise
technique would be to force keeping only of arguments that we
need for deriving the dimension variables. However, that would be
a much more involved change, for an uncertain benefit.
2023-03-27 13:37:39 +02:00
George Necula
99facbab2a [jax2tf] Turn an error into a warning with native serialization
We want to allow using native_serialization_platforms even if the native_serialization is False. This is useful for code that is runnable with and without native serialization.

PiperOrigin-RevId: 519649827
2023-03-27 01:24:56 -07:00
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
ec427f2c95 Split dtype argument from other arguments in special functions.
This helps pytype to determine that the arguments are of different kinds, preventing type errors.

PiperOrigin-RevId: 519401250
2023-03-25 11:41:14 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Anish Tondwalkar
6842e98ca1 Migrate regularized_incomplete_beta_p off xla_fallback
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00