16594 Commits

Author SHA1 Message Date
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Chris Jones
3f9da19c63 Add get_serialized_metadata function to retrieve metadata from op's opaque data.
PiperOrigin-RevId: 544608895
2023-06-30 03:23:28 -07:00
jax authors
2575307c04 Merge pull request #16600 from jakevdp:schur-jvp
PiperOrigin-RevId: 544603688
2023-06-30 02:58:06 -07:00
George Necula
a815a89c21 Add backwards compatibility test for Mosaic.
I set it up to use some small helper functions that we use for other JAX custom calls.

We should think what kind of tests we actually need. The boilerplate that I set up here makes sense if we plan to have more than one test. E.g., do we want to test backwards compatibility only for the calling conventions of tpu_custom_call, or also that it gives the same behavior over multiple ops?

PiperOrigin-RevId: 544602453
2023-06-30 02:49:38 -07:00
Jake VanderPlas
a329f8b947 schur: fix broken jvp rule 2023-06-30 02:30:25 -07:00
Tao Wang
16f72cf903 Add get_profiled_instructions_proto API in jax.experimental.profiler to get profiled instructions proto.
PiperOrigin-RevId: 544540912
2023-06-29 21:02:05 -07:00
Marcus Chiam
bbd3824332 Flax is changing the RNNCellBase API:
- when calling the constructor of a class, it is now required to pass in a `features` argument
- when calling the `initialize_carry` method, instead of passing in the `batch_dims` and `size`, you only have to pass in an `input_shape`

More details about the changes and how to upgrade to the new API can be found [here](https://flax--3053.org.readthedocs.build/en/3053/guides/rnncell_upgrade_guide.html).

PiperOrigin-RevId: 544461085
2023-06-29 14:21:02 -07:00
George Necula
46aa9e0b31 Copybara import of the project:
--
b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d by George Necula <gcnecula@gmail.com>:

[shape_poly] Fix lowering when we have both dimension variables and tokens

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16575 from gnecula:call_tf_poly b07be45e8cecd492e3f269907cf4a2d5ec6a8b4d
PiperOrigin-RevId: 544252624
2023-06-28 22:15:16 -07:00
jax authors
64b0962f4e Merge pull request #16581 from froystig:cond-changelog
PiperOrigin-RevId: 544158825
2023-06-28 14:18:49 -07:00
Roy Frostig
48903a382e add corner-case cond resolution fix to changelog 2023-06-28 10:09:10 -07:00
George Necula
f463437c7e Cleanup API usage related to shape polymorphism.
We import jax._src.core instead of jax.core because we need access to JAX internal symbols (core.is_constant_shape). This is in preparation for removing some symbols from the public APIs.

PiperOrigin-RevId: 544063204
2023-06-28 08:23:15 -07:00
Chris Jones
d4e2464340 [jax_triton] Expose Triton custom call callback in header file.
This allows users to register the callback from C++ when not using the default call target name.

PiperOrigin-RevId: 544029098
2023-06-28 05:32:02 -07:00
jax authors
5b698c899e Merge pull request #16502 from jakevdp:pxla-deprecation
PiperOrigin-RevId: 543952523
2023-06-27 22:40:27 -07:00
jax authors
2948b56944 Merge pull request #16567 from jakevdp:fix-gather-batching
PiperOrigin-RevId: 543951932
2023-06-27 22:32:12 -07:00
Jake VanderPlas
3f47ad367d jax.interpreters.pxla: remove deprecated functions:
- jax.interpreters.pxla.device_put
- jax.interpreters.pxla.make_sharded_device_array
2023-06-27 21:49:55 -07:00
Jake VanderPlas
18bbc96279 Fix integer overflow in gather batching rule 2023-06-27 21:45:45 -07:00
Roy Frostig
14f32653a1 resolve conditionals to default "shared operand form" more often
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.

This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.

PiperOrigin-RevId: 543912445
2023-06-27 18:49:16 -07:00
jax authors
94c3e45d03 Merge pull request #16566 from jakevdp:stats-binom
PiperOrigin-RevId: 543797338
2023-06-27 11:07:35 -07:00
Jake VanderPlas
30d1a8a80f Add jax.scipy.stats.binom 2023-06-27 03:41:38 -07:00
jax authors
6bc74d2a98 Merge pull request #16562 from froystig:outline-split
PiperOrigin-RevId: 543659641
2023-06-27 00:34:37 -07:00
George Necula
cb42fae810 [shape_poly] Shape polymorphism support for approx_top_k
PiperOrigin-RevId: 543633818
2023-06-26 22:02:41 -07:00
Yash Katariya
744a64fce6 Make sharding on ShapeDtypeStruct a property that always exists. The previous behavior was it only existed if sharding was not None.
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, jax will choose to mark the input as replicated but JAX reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
2023-06-26 21:46:50 -07:00
Parker Schuh
819f731e8d jax.lax.collapse now takes Nones for stop_dimension.
PiperOrigin-RevId: 543598626
2023-06-26 18:30:34 -07:00
Yash Katariya
c632cace1e Raise an error if a user passes None to host_local_array_to_global_array or global_array_to_host_local_array
PiperOrigin-RevId: 543596009
2023-06-26 18:15:43 -07:00
George Necula
c6a60054b9 [shape_poly] linalg.schur: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543533821
2023-06-26 13:59:01 -07:00
George Necula
a91412e1e7 [shape_poly] linalg.triangular_solve: shape polymorphism with native lowering on CPU
PiperOrigin-RevId: 543506845
2023-06-26 12:13:12 -07:00
Roy Frostig
690e626312 outline jitted threefry split and fold_in subroutines
We may want to continue to inline these in Jaxpr, but it's useful to
outline them in HLO for visualization and debugging.
2023-06-26 11:52:55 -07:00
George Necula
ea0e50f765 [shape_poly] Refactor support for dynamic shapes for linalg.eig and linalg.eigh
The support for dynamic shapes for linalg.eig and linalg.eigh has been added
before we added the helper function `mk_result_types_and_shapes`, which has
been used for all other linalg primitives. Here we refactor linalg.eig and
linalg.eigh support to use these helper functions and follow the same style
as for other linalg primitives.

PiperOrigin-RevId: 543495381
2023-06-26 11:31:31 -07:00
jax authors
be0f4ce0d2 Merge pull request #16554 from gnecula:poly_refactor
PiperOrigin-RevId: 543467461
2023-06-26 10:06:49 -07:00
jax authors
dad18c1136 Merge pull request #16467 from gnecula:safety_checks
PiperOrigin-RevId: 543467440
2023-06-26 09:58:43 -07:00
George Necula
aadcec2b1b [shape_poly] Refactor some older support for shape polymorphism for linalg.
Moving some helper functions from linalg.py to hlo_helpers.py, so that we
can reuse them for more custom calls, including those in gpu_solver.

Also renamed some helper functions, e.g., _hlo_s32 -> hlo_s32, and ir_constant_i32 -> hlo_s32.

PiperOrigin-RevId: 543448560
2023-06-26 08:40:22 -07:00
Jieying Luo
21588a30a9 [PJRT C API] Add related C type definitions for key value get/put callback, as well as conversion between C and cpp types.
This is similar to how send/receive callback are implemented.

Update make_c_api_client to take key value get/put callback generated from distributed client, and optiosn of node_id and num_nodes.

PiperOrigin-RevId: 543441403
2023-06-26 08:08:52 -07:00
George Necula
c800ef8b6c [jax2tf] Improve custom call safety check to detect mhlo.custom_call 2023-06-25 19:48:55 +02:00
George Necula
2299f05b8b [shape_poly] Cleanup the evaluation of dynamic shapes
Previously, we used the following pattern to generate the 1D
tensors representing dynamic shapes:

```
mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, shape))
```

Now we write:
```
mlir.eval_dynamic_shape_as_tensor(ctx, shape)
```
2023-06-25 18:20:50 +02:00
Parker Schuh
feced360f0 Make the default driver in serialization be a global constant.
PiperOrigin-RevId: 543008650
2023-06-23 18:40:31 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129 Merge pull request #16430 from jakevdp:bool-error
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
jax authors
9d480e2e07 Merge pull request #16533 from hawkinsp:winreadme
PiperOrigin-RevId: 542949760
2023-06-23 13:52:09 -07:00
Peter Hawkins
a1de687382 Update README to mention Windows support. 2023-06-23 16:17:18 -04:00
jax authors
01a16f5914 Merge pull request #16487 from jakevdp:convolve-dtype
PiperOrigin-RevId: 542929304
2023-06-23 12:32:36 -07:00
jax authors
1e1992ae2b Merge pull request #16535 from jakevdp:numpy-122
PiperOrigin-RevId: 542900706
2023-06-23 10:47:24 -07:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
jax authors
63415a9184 Merge pull request #16386 from axch:ragged-einsum
PiperOrigin-RevId: 542887557
2023-06-23 10:00:07 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
Peter Hawkins
bfa113ba60 Remove references to Python 3.8.
Remove the old build scripts/Dockerfile, since they are unused and broken.

PiperOrigin-RevId: 542870354
2023-06-23 08:48:57 -07:00
Alexey Radul
bb7d918429 Type annotations. 2023-06-23 10:56:27 -04:00
George Necula
bbc6f30693 [shape_poly] linalg.lu: for shape polymorphism for native serialization on CPU.
We support polymorphism only on the batch sizes for now. The
jaxlib and C++ code support full dynamic shapes.

Also added backwards compatibility tests for the LU custom calls
for CPU, and improved the checking of LU results by checking
the invariant for the result as opposed to checking goldens.

PiperOrigin-RevId: 542852925
2023-06-23 07:25:24 -07:00
jax authors
935579db07 Merge pull request #16514 from ayaka14732:main
PiperOrigin-RevId: 542790062
2023-06-23 01:34:50 -07:00
jax authors
7205e553f3 Merge pull request #16503 from jakevdp:sparsify-custom-jvp
PiperOrigin-RevId: 542790010
2023-06-23 01:26:41 -07:00
jax authors
1b66d97b4c Merge pull request #16428 from jakevdp:prng-tests
PiperOrigin-RevId: 542789970
2023-06-23 01:18:07 -07:00