Jake VanderPlas
8f72454bdf
Add internal jax.lax.asarray utility
2023-03-30 10:21:55 -07:00
Peter Hawkins
67a28ce30f
Relax test tolerances for testLogisticPpf.
...
Fixes a test failure in CI.
PiperOrigin-RevId: 520649225
2023-03-30 08:41:56 -07:00
jax authors
dedfc8df75
Merge pull request #15282 from JiaYaobo:geom_random
...
PiperOrigin-RevId: 520635974
2023-03-30 07:45:19 -07:00
jax authors
1fd6e01289
Merge pull request #15287 from gnecula:tf_dim_vars
...
PiperOrigin-RevId: 520633830
2023-03-30 07:37:47 -07:00
jax authors
794769c113
Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors
...
PiperOrigin-RevId: 520632081
2023-03-30 07:29:51 -07:00
jax authors
0a2e383eaf
Merge pull request #15297 from jakevdp:finfo-props
...
PiperOrigin-RevId: 520632058
2023-03-30 07:22:28 -07:00
George Necula
081b86b82a
[shape_poly] Improved computation of dimension variables for native serialization
...
Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)
The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape if `2 *b`.
This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.
What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.
After we land this, we will refactor the shape environment to be cleaner.
2023-03-30 15:51:24 +02:00
Peter Hawkins
47177e1417
Split more targets out the main JAX Bazel target.
...
Namely:
* abstract_arrays
* ad_util
* api_util
* interpreters/partial_eval
* lax_reference
PiperOrigin-RevId: 520618715
2023-03-30 06:12:45 -07:00
Matthew Johnson
81de5b7a0d
improve pmap in_axes/out_axes pytree prefix error messages
2023-03-29 16:56:40 -07:00
Peter Hawkins
3135fbcd7f
[JAX] Delete _DeviceArray and DeviceArray.
...
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
jax authors
6ce5eeb85f
Merge pull request #15299 from jakevdp:fix-docs
...
PiperOrigin-RevId: 520439293
2023-03-29 14:14:41 -07:00
Jake VanderPlas
8562a8d7bc
DOC: pin pydata-sphinx-theme to fix incompatibility
2023-03-29 13:55:32 -07:00
Yash Katariya
830cd9fd98
Delete _single_device_array_from_buf
since everything from JAX is an Array
...
PiperOrigin-RevId: 520418231
2023-03-29 12:59:12 -07:00
Jake VanderPlas
5759bf05df
jnp.finfo: add missing properties
2023-03-29 11:23:51 -07:00
Peter Hawkins
f48dbf039e
Use cast instead of pytype suppression in setops.py.
...
pytype cannot tell from the type signature that unique() returns an Array, not a tuple. Add a cast to help it along.
It's possible that a future use of @overload on the definition of jnp.unique might help.
PiperOrigin-RevId: 520389675
2023-03-29 11:17:00 -07:00
jiayaobo
924894e85c
add geometric random gen
...
add geom random
add geom random
add geom random
add geom random
2023-03-30 02:08:04 +08:00
Roy Frostig
c8a7d5990d
fix custom_jvp
check for tracers in arguments marked nondiff_argnums
...
PiperOrigin-RevId: 520379098
2023-03-29 10:42:06 -07:00
Jake VanderPlas
f282c251d4
Add minimal pyproject.toml specifying build system
...
Replaces #15274 , Fixes #15256
PiperOrigin-RevId: 520367622
jax-v0.4.8
2023-03-29 10:08:30 -07:00
jax authors
cfa330b6fa
Merge pull request #15283 from JiaYaobo:fix_wald_doc
...
PiperOrigin-RevId: 520364879
2023-03-29 10:01:27 -07:00
jax authors
2d94f76ca3
Merge pull request #15278 from hawkinsp:cudainstall
...
PiperOrigin-RevId: 520364354
2023-03-29 09:53:58 -07:00
Yash Katariya
fbc05ee5ac
Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
...
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
jax authors
a964ae7fac
Internal Code Change
...
PiperOrigin-RevId: 520341781
2023-03-29 08:23:56 -07:00
jax authors
7200d07db5
Merge pull request #15286 from hawkinsp:testjobs
...
PiperOrigin-RevId: 520319910
2023-03-29 06:39:04 -07:00
Peter Hawkins
d9b0f3cd6f
Recommend --local_test_jobs in bazel test command line on GPU.
2023-03-29 09:28:53 -04:00
jax authors
07fc0222a1
Merge pull request #15279 from hawkinsp:versions
...
PiperOrigin-RevId: 520307157
2023-03-29 05:25:50 -07:00
jiayaobo
3a4d0b3552
remove scale in wald docstring
2023-03-29 11:39:49 +08:00
Peter Hawkins
705b5cc000
Add version constraints to CUDA pip wheel dependencies.
...
Fixes https://github.com/google/jax/issues/15267
2023-03-28 21:55:32 -04:00
Peter Hawkins
775f404f31
Update the CUDA installation instructions.
2023-03-28 21:46:07 -04:00
Peter Hawkins
c2d6fcc0e6
Split core.py and several files in an SCC with it into a separate Bazel build target.
...
PiperOrigin-RevId: 520192610
2023-03-28 18:31:13 -07:00
jax authors
8c4fed6410
Merge pull request #15270 from skye:pjrt_c_api
...
PiperOrigin-RevId: 520156646
2023-03-28 15:48:27 -07:00
Skye Wanderman-Milne
473d1c3685
Turn on PJRT C API by default.
...
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
2023-03-28 15:28:13 -07:00
Rebecca Chen
5ae2e79d43
Silence some pytype errors.
...
PiperOrigin-RevId: 520150523
2023-03-28 15:24:48 -07:00
jax authors
4061bbbbc2
Merge pull request #15269 from skye:min_jaxlib_version
...
PiperOrigin-RevId: 520127548
2023-03-28 14:02:27 -07:00
Skye Wanderman-Milne
00acf459c6
Bump minimum jaxlib version from 0.4.6 to 0.4.7.
...
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
jax authors
014033b75e
Merge pull request #15266 from mehdiataei:mehdiataei-patch-1
...
PiperOrigin-RevId: 520111030
2023-03-28 13:09:56 -07:00
jax authors
bbec461c8b
Merge pull request #15263 from jakevdp:deprecations
...
PiperOrigin-RevId: 520110559
2023-03-28 13:02:32 -07:00
Jake VanderPlas
fc47137ca8
Add deprecation warnings for several top-level jax imports
2023-03-28 12:40:59 -07:00
Yash Katariya
2f105bde2d
Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning
...
PiperOrigin-RevId: 520098757
2023-03-28 12:17:30 -07:00
mehdiataei
8d090a01d0
Fixed spelling error in msgs
2023-03-28 14:48:18 -04:00
jax authors
2fbccc8f5e
Merge pull request #15251 from jakevdp:mypy-deps
...
PiperOrigin-RevId: 520079899
2023-03-28 11:11:15 -07:00
Yash Katariya
7442faa715
Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding.
...
PiperOrigin-RevId: 520072687
2023-03-28 10:47:42 -07:00
Yash Katariya
97c8ce31ed
Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default.
...
PiperOrigin-RevId: 520067392
2023-03-28 10:33:00 -07:00
Colin Gaffney
6ace6667dd
Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing.
...
PiperOrigin-RevId: 520064049
2023-03-28 10:25:07 -07:00
George Necula
2ac2dc65b1
Remove jax2tf experimental_native_lowering.
...
Users should use native_serialization.
PiperOrigin-RevId: 520063928
2023-03-28 10:17:58 -07:00
Yash Katariya
86c0b36bfd
Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12
...
PiperOrigin-RevId: 520056811
2023-03-28 09:54:36 -07:00
jax authors
4106c35e0a
Merge pull request #15258 from gnecula:dim_vars
...
PiperOrigin-RevId: 520033493
2023-03-28 08:21:00 -07:00
George Necula
a1538c73b4
[shape_poly] Refactor the computation of the dimension variables in native serialization
...
Currently, JAX native serialization produces a module whose main function
takes additional arguments for the values of the dimension variables. These
values are then resolved in the XlaCallModule based on a dim_args_spec
parameter.
We move the code that computes the dimension variables from XlaCallModule to
jax_export following pretty much the same technique. This simplifies
XlaCallModule and especially its API (the dim_args_spec). So far this
is just a refactoring with no semantic changes, but this will allow us
to improve the support for dimension variables that occur in linear
polynomials, e.g., "2*b" rather than just "b".
2023-03-28 12:51:48 +02:00
jax authors
4533578bba
Merge pull request #15206 from jakevdp:expm-batch
...
PiperOrigin-RevId: 519883991
2023-03-27 18:21:27 -07:00
jax authors
03d1442c4f
Merge pull request #15241 from jakevdp:instance-check
...
PiperOrigin-RevId: 519868384
2023-03-27 17:06:00 -07:00
jax authors
6f39237bc0
Merge pull request #15243 from jakevdp:coo-sort-warning
...
PiperOrigin-RevId: 519864866
2023-03-27 16:50:24 -07:00