13532 Commits

Author SHA1 Message Date
Jake VanderPlas
2f27d516d7 [typing] annotate next part of lax_numpy.py 2022-10-25 12:36:26 -07:00
Yash Katariya
cf6b5097d0 Remove pytest_benchmark for test-requirements.txt and move the benchmark file which was using that package to use google_benchmark.
PiperOrigin-RevId: 483736267
2022-10-25 11:59:32 -07:00
jax authors
548d7f4599 Merge pull request #12976 from jakevdp:cuda-release-comments
PiperOrigin-RevId: 483734792
2022-10-25 11:41:50 -07:00
Jake VanderPlas
41b815bf33 Fix jax releases URL in setup.py comments 2022-10-25 10:39:39 -07:00
Yash Katariya
adcb0f58e8 Add __repr__ and __str__ to PmapSharding.
Fixes https://github.com/google/jax/issues/12971

PiperOrigin-RevId: 483707874
2022-10-25 10:13:02 -07:00
jax authors
387cce2d06 Merge pull request #12958 from jakevdp:fix-gradient
PiperOrigin-RevId: 483706501
2022-10-25 10:04:54 -07:00
Peter Hawkins
a852710a09 Merge CUDA and ROCM kernel code in jaxlib.
The code for both CUDA and ROCM is almost identical, so with a small shim library to handle the differences we can share almost everything.

PiperOrigin-RevId: 483666051
2022-10-25 07:23:34 -07:00
jax authors
621f06660d Merge pull request #12965 from gnecula:tf_readme
PiperOrigin-RevId: 483626158
2022-10-25 03:38:48 -07:00
George Necula
05b9c3be57 [jax2tf] Fixes uses of tf.Variable in README.md 2022-10-25 10:06:05 +03:00
Jake VanderPlas
2009e65a33 jnp.gradient: call check_arraylike on inputs & clean-up implementation 2022-10-24 15:27:33 -07:00
jax authors
70f659a24e Merge pull request #12957 from jakevdp:fix-lstsq
PiperOrigin-RevId: 483493703
2022-10-24 14:59:30 -07:00
jax authors
15b415b336 Merge pull request #12951 from jakevdp:annotate-lax-numpy
PiperOrigin-RevId: 483490497
2022-10-24 14:47:11 -07:00
Jake VanderPlas
56d42c0edf [typing] annotate next batch of lax_numpy 2022-10-24 14:21:35 -07:00
Jake VanderPlas
9ade89ea62 jnp.linalg.lstsq: handle zero-size inputs 2022-10-24 14:10:31 -07:00
jax authors
964988c968 Merge pull request #12953 from eltociear:patch-1
PiperOrigin-RevId: 483471278
2022-10-24 13:34:16 -07:00
Ikko Ashimine
28def736d1
Fix typo in 9419-jax-versioning.md
overriden -> overridden
2022-10-25 03:26:48 +09:00
jax authors
b892108043 Merge pull request #12950 from jakevdp:fix-ci-error
PiperOrigin-RevId: 483425722
2022-10-24 10:48:43 -07:00
jax authors
8f2f9f4563 Merge pull request #12646 from adrn:truncnorm
PiperOrigin-RevId: 483425197
2022-10-24 10:41:51 -07:00
Peter Hawkins
894093c0fb Move jaxlib cpu kernels under jaxlib/cpu/.
No functional changes intended.

PiperOrigin-RevId: 483413031
2022-10-24 10:02:56 -07:00
Jake VanderPlas
48e680c839 CI: avoid raising error when wrapped function is None 2022-10-24 08:57:53 -07:00
jax authors
67fa7c27d5 Typo fix.
PiperOrigin-RevId: 483380789
2022-10-24 07:53:50 -07:00
Adrian Price-Whelan
5784d61048 implement truncnorm in jax.scipy.stats
fix some shape and type issues

import into namespace

imports into non-_src library

working logpdf test

cleanup

working tests for cdf and sf after fixing select

relax need for x to be in (a, b)

ensure behavior with invalid input matches scipy

remove enforcing valid parameters in tests

added truncnorm to docs

whoops alphabetical

fix linter error

fix circular import issue
2022-10-22 15:48:20 -04:00
Xin Zhou
b07c586565 [mhlo] Use 11 out of 12 new shared type inferences from StableHLO.
The shape function of DotGeneralOp can't be integrated into MHLO yet: the shape function only predicts return shape but not able to predict element type. However, the current python binding infra will generate the constructor __init__() without the `return` as the first arg, which assumes the shape function can provide a fully inferred type (including an accurate element type). This leads to "inferred type does not match actual result type" errors in JAX. This needs a future solution.

This CL is the corresponding change with https://github.com/openxla/stablehlo/pull/269

Related Python __init__() interface changes (used by JAX):
batch_norm_grad:      not used by JAX
batch_norm_inference: not used by JAX
batch_norm_training:  not used by JAX
case:                 no change*
dot_general:          open new b/253644255 to track the issue
if:                   no change*
map:                  no change*
reduce:               no change*
reduce_window:        no change*
sort:                 no change*
triangular_solve:     updated in `linalg.py`
while:                no change*

no change*: the signature of __init()__ for the op is not changed because of existence of regions https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp#L577

PiperOrigin-RevId: 482951512
2022-10-21 20:34:04 -07:00
jax authors
3be5ab218a Allow calling initialize_cache a second time if the path is the same.
PiperOrigin-RevId: 482945880
2022-10-21 19:54:09 -07:00
Yash Katariya
9956ad2f89 Add more pjit tests and make some tests go via actual computations rather than trivial computation.
PiperOrigin-RevId: 482919649
2022-10-21 16:53:53 -07:00
jax authors
a4e366394b Merge pull request #12921 from jakevdp:lax-numpy-dtypes
PiperOrigin-RevId: 482905407
2022-10-21 15:42:40 -07:00
jax authors
64e996e73b Merge pull request #12925 from jakevdp:annotate-lax
PiperOrigin-RevId: 482903592
2022-10-21 15:36:22 -07:00
jax authors
4acc293ffb Merge pull request #12923 from jakevdp:nperseg-test
PiperOrigin-RevId: 482902569
2022-10-21 15:30:04 -07:00
Tianjian Lu
e219d55c36 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1
PiperOrigin-RevId: 482897448
2022-10-21 15:06:17 -07:00
Qiao Zhang
4e8fbd0239 Add delete method to GlobalDeviceArray and ShardedBuffer.
This ensures all existing JAX buffer types have a `delete` method that can be used to free device buffer allocation eagerly.

User code sometimes have lingering python refs due to cyclic deps and other reasons, yet users may know for sure that certain arrays will no longer be used after a certain point. Calling `foo_array.delete()` for DeviceArray/ShardedDeviceArray/GlobalDeviceArray/Array allows users to force free the device side allocation to minimize device memory usage.

PiperOrigin-RevId: 482892157
2022-10-21 14:42:32 -07:00
Jake VanderPlas
ca7d05f4f1 [typing] fix incorrect type annotation on lax.argmax/argmin 2022-10-21 14:37:59 -07:00
Yash Katariya
37a015690d First pass at adding GetOutputShardings and GetParameterShardings on PjRTExecutable.
PiperOrigin-RevId: 482878289
2022-10-21 13:43:52 -07:00
jax authors
e297b13ef2 Merge pull request #12899 from jakevdp:dead-code
PiperOrigin-RevId: 482867884
2022-10-21 13:02:47 -07:00
Jake VanderPlas
4714a5cc8f Add regression test for #12920 2022-10-21 12:52:32 -07:00
jax authors
045d6e17b0 Merge pull request #12920 from yilei:nameerror
PiperOrigin-RevId: 482864751
2022-10-21 12:49:26 -07:00
Jake VanderPlas
97b17af5be [typing] add type annotations to the first several lax_numpy functions 2022-10-21 11:59:53 -07:00
Yilei Yang
d63d038eeb Flatten the if/else block. 2022-10-21 11:27:45 -07:00
Yilei "Dolee" Yang
7c1bf0e7cd
Fix an NameError caused by #12754 2022-10-21 10:22:36 -07:00
jax authors
503679d636 Merge pull request #12904 from froystig:dlpack64
PiperOrigin-RevId: 482798461
2022-10-21 08:23:18 -07:00
Benjamin Kramer
dd04953361 [MLIR] Don't rely on hardcoded -1 for dynamic axis sizes
The magic number might change, use an accessor to get it.

PiperOrigin-RevId: 482796475
2022-10-21 08:13:02 -07:00
jax authors
2228efe277 Merge pull request #12876 from mattjj:djax-vmap3
PiperOrigin-RevId: 482693137
2022-10-20 22:39:51 -07:00
Matthew Johnson
8e8ae8441f fix 2022-10-20 22:23:29 -07:00
Matthew Johnson
f76fc010e4 put back some unsafe_maps 2022-10-20 21:56:00 -07:00
jax authors
8030b49b23 Merge pull request #12912 from mattjj:issue12909
PiperOrigin-RevId: 482685295
2022-10-20 21:44:09 -07:00
Matthew Johnson
6a3d2a0dde update docs to point to jax.nn.standardize
Fixes #12909
2022-10-20 21:25:37 -07:00
jax authors
2d45d8b3a6 Merge pull request #12800 from mattjj:ayaka
PiperOrigin-RevId: 482676604
2022-10-20 20:40:08 -07:00
Matthew Johnson
60b236cff0 improve (and shorten!) pmap error messages about inconsistent axis sizes 2022-10-20 18:31:40 -07:00
jax authors
1b5294e597 Merge pull request #12906 from jakevdp:stats-mode-doc
PiperOrigin-RevId: 482632489
2022-10-20 16:37:56 -07:00
John QiangZhang
408953bc3e fix jax2tf readme typo
PiperOrigin-RevId: 482625385
2022-10-20 16:09:00 -07:00
jax authors
6e4b13571b Merge pull request #12896 from jakevdp:unused-imports
PiperOrigin-RevId: 482623200
2022-10-20 16:01:02 -07:00