5739 Commits

Author SHA1 Message Date
Yash Katariya
e21aee18a8 Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Sharad Vikram
3c3fa042e3 Copy seq_lengths before creating descriptor
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Yash Katariya
41695cc78c Temporarily fix the compilation cache test which is failing on latest jaxlib release
PiperOrigin-RevId: 519745099
2023-03-27 09:37:37 -07:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
jax authors
c572155cc1 Merge pull request #15212 from google:pjrt_c_api_tests
PiperOrigin-RevId: 519276265
2023-03-24 17:27:40 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Anish Tondwalkar
6842e98ca1 Migrate regularized_incomplete_beta_p off xla_fallback
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00
Anish Tondwalkar
ac44d2c2e3 Migrate besseli0e off xla_fallback
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
Parker Schuh
21541e60b1 Guard ArrayImpl checks by xla_extension_version.
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
Yash Katariya
bc231ee0bf After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
Anish Tondwalkar
8c75e27f67 Migrate random_gamma_grad off xla_fallback
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618 Migrate igamma_grad_a_p off xla_fallback
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Anish Tondwalkar
4a9b09485e Migrate igammac_p off xla_fallback path
It is now decomposed into stablehlo ops.

PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
jax authors
32e712864c Merge pull request #15192 from mattjj:issue15190
PiperOrigin-RevId: 519037959
2023-03-23 20:48:33 -07:00
jax authors
1982a113d6 Merge pull request #15187 from mattjj:djax-revival
PiperOrigin-RevId: 519036576
2023-03-23 20:38:01 -07:00
Matthew Johnson
7743fcd758 [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 2023-03-23 20:20:01 -07:00
Matthew Johnson
793387e496 fix jax.Array.round()
fixes #15190
2023-03-23 20:16:23 -07:00
Skye Wanderman-Milne
4cb3b011a0 Remove PJRT C API bypass.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.

PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
jax authors
e9bc7ee866 Merge pull request #15184 from jakevdp:move-median
PiperOrigin-RevId: 519003606
2023-03-23 17:14:50 -07:00
Jake VanderPlas
6f8885a0c2 lax_numpy: move quantile-based functions to reductions.py 2023-03-23 16:39:20 -07:00
Anish Tondwalkar
f981243af5 Migrate igamma_p off xla_fallback
We decompose it into a series or a call to igammac.

PiperOrigin-RevId: 518993077
2023-03-23 16:26:59 -07:00
John QiangZhang
171b22dbbc Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990
PiperOrigin-RevId: 518984018
2023-03-23 15:50:16 -07:00
Matthew Johnson
ba2ff519ca improve scan error messages 2023-03-23 14:53:05 -07:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00
jax authors
383cf41848 Merge pull request #14937 from b0nce:fix-stats
PiperOrigin-RevId: 518888600
2023-03-23 10:01:54 -07:00
jax authors
54e8101f00 Merge pull request #15123 from jakevdp:fix-mean-large-dims
PiperOrigin-RevId: 518852476
2023-03-23 07:23:20 -07:00
jax authors
e39578cd73 Merge pull request #15154 from mattjj:pjit-typecheck
PiperOrigin-RevId: 518717095
2023-03-22 17:31:59 -07:00
Parker Schuh
484eb26d2a Redefine compile_and_serialize as serialize(lowered.compile()).
This has the downside of keeping around the UnloadedMeshComputation,
but it makes the serialize() API easier to understand.

PiperOrigin-RevId: 518715469
2023-03-22 17:23:19 -07:00
Matthew Johnson
268456ef54 enable pjit recursive typechecking
Give pjit_p a custom typecheck rule, which basically just calls the
core._check_call utility (which was made for xla_call_p and core.call_p).

This revealed the need for a slight generalization of the custom_typecheck rule
signature, for better "context-aware" printing of jaxpr type errors: the rules
should have a `ctx_factory` first argument. **The reason this PR touches so
many files is just that it makes the trivial tweaks to all existing typecheck
rules to accomodate that new signature.** I didn't adapt any other higher-order
primitives' rules to actually use the context, but presumably errors for HOPs
like scan would be improved by using it. Follow-up work!

It's key that core._check_call works with dynamic shapes; this PR is soon to be
followed by some djax+pjit PRs!
2023-03-22 16:59:22 -07:00
jax authors
5f724cf9a2 Merge pull request #15151 from mattjj:qiao-remat-print-res-info
PiperOrigin-RevId: 518709114
2023-03-22 16:53:22 -07:00
Matthew Johnson
6b4262d9f6 add experimental jax_log_checkpoint_residuals option
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc) but this is a start.

Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
2023-03-22 16:26:56 -07:00
Sharad Vikram
12761f966e Add print statement to help debug spurious test failure 2023-03-22 21:47:15 +00:00
jax authors
dd2ecf4bb5 Merge pull request #15085 from mattjj:arg-info-in-mlir-5
PiperOrigin-RevId: 518642948
2023-03-22 12:31:35 -07:00
Jake VanderPlas
398532b32c jax.random: remove scale from wald function 2023-03-22 10:44:50 -07:00
Peter Hawkins
64e1f5fe3d Revert: custom_vjp symbolic zeros support
PiperOrigin-RevId: 518597609
2023-03-22 09:56:09 -07:00
jax authors
00e6c73b68 Merge pull request #15114 from JiaYaobo:add_wald_random
PiperOrigin-RevId: 518592428
2023-03-22 09:37:44 -07:00
jiayaobo
f7a14d65d2 add wald random generator
add wald to random.py
2023-03-22 11:06:59 +08:00
jax authors
1703f096b5 Merge pull request #15088 from froystig:custom-vjp-symbolic-zeros
PiperOrigin-RevId: 518433198
2023-03-21 18:05:33 -07:00
Jieying Luo
b403c2a083 [PJRT C API] Add parsing PJRT client create options from json file.
PiperOrigin-RevId: 518418760
2023-03-21 16:57:34 -07:00
Yash Katariya
9a0de29114 Remove the config.jax_array and jax_jit_pjit_api_merge flag usage since those are always True
PiperOrigin-RevId: 518368963
2023-03-21 13:42:26 -07:00
jax authors
1b141ed9ca Merge pull request #15120 from nouiz:inspect_array_sharding
PiperOrigin-RevId: 518307962
2023-03-21 10:12:00 -07:00
Jake VanderPlas
ed0170c8c4 jnp.mean: fix incorrect return value for large arrays 2023-03-21 09:36:29 -07:00
Yash Katariya
b5c9c0f47e Raise a better error message when there is a device assignment mismatch via the apply_primitive route.
PiperOrigin-RevId: 518282464
2023-03-21 08:40:42 -07:00
Frederic Bastien
42e9753431 Fix inspect_array_sharding with grad. 2023-03-21 07:58:27 -07:00
Roy Frostig
ac7491ced0 custom_vjp symbolic zeros support 2023-03-21 14:14:35 +00:00
Peter Hawkins
e0453add22 Mark jax.interpreters.pxla.ShardedDeviceArray as deprecated.
PiperOrigin-RevId: 518241326
2023-03-21 05:13:55 -07:00
Matthew Johnson
da3799959a separate register_pytree_node and register_pytree_with_keys tests 2023-03-20 20:05:47 -07:00
Misha
83b3f5b759 Fix loc and scale parameters in scipy.logistic. Add CDF and SF for several distributions. 2023-03-21 00:16:13 +01:00