6165 Commits

Author SHA1 Message Date
Tao Wang
6eb3096461 Enable to set fdo_profile through XLA python client.
PiperOrigin-RevId: 547303330
2023-07-11 14:47:39 -07:00
Jake VanderPlas
a29d4bcd33 remove deprecation warning test in preparation for removing deprecated APIs
PiperOrigin-RevId: 547229078
2023-07-11 10:52:10 -07:00
Juliana Franco
f81a48a819 Makes it possible to lower primitives with user-defined lowering rules.
PiperOrigin-RevId: 547228102
2023-07-11 10:26:07 -07:00
jax authors
17c4b57f97 Merge pull request #16671 from jakevdp:std-args
PiperOrigin-RevId: 547227744
2023-07-11 10:25:53 -07:00
jax authors
e894e4817a Remove deprecated compiler_ir from Compiled
PiperOrigin-RevId: 547211085
2023-07-11 09:24:48 -07:00
jax authors
2fa6a9c9bf Allow other backends to run the array_test.py test.
PiperOrigin-RevId: 547191886
2023-07-11 08:05:25 -07:00
jax authors
f7a71e4ca5 Merge pull request #16543 from ROCmSoftwarePlatform:rocm-enable-eighidentity-test
PiperOrigin-RevId: 547179014
2023-07-11 07:13:28 -07:00
jax authors
3ec5f73db0 Merge pull request #16542 from ROCmSoftwarePlatform:rocm-enable-svdontiny-test
PiperOrigin-RevId: 547178719
2023-07-11 07:04:45 -07:00
Roy Frostig
1ad0a11897 AOT: better error messages on call signature mismatch
Also update error example in AOT docs.
2023-07-10 22:10:50 -07:00
Jake VanderPlas
1b3da85758 Fix scatter batching rule for scatter_apply
The issue is that the batching rule assumes that each scatter variant
always has the same update_jaxpr. This is not true of scatter_apply, which
lowers to scatter with a custom update_jaxpr. To address this, we change
the batching rule such that it re-uses the input jaxpr rather than always
re-generating it.
2023-07-10 16:42:45 -07:00
jax authors
f4eed78e90 Merge pull request #16645 from treyra:main
PiperOrigin-RevId: 546987247
2023-07-10 14:41:31 -07:00
Jake VanderPlas
d7bb9f85d6 NumpySignaturesTest: account for 'mean' param to std/var 2023-07-10 09:56:17 -07:00
treyra
b0c309a25c Added test for vmap inconsistent sized arrays msg 2023-07-09 20:46:40 -07:00
jax authors
1795b12a9f Merge pull request #16654 from jakevdp:ml-dtypes-version
PiperOrigin-RevId: 546366165
2023-07-07 13:13:55 -07:00
Jake VanderPlas
9962065deb Require ml_dtypes>=0.2 2023-07-07 12:07:44 -07:00
Alexey Radul
defe71228c Clearer test names. 2023-07-07 09:23:33 -04:00
Alexey Radul
aa3c49f134 Test a different configuration of einsum.
This version stresses my transpose_ragged_axes method, which, it
seems, was interpreting the permutation the wrong way.  Fixed.
2023-07-07 09:23:33 -04:00
Alexey Radul
89dd69ea2d Test and implement ragged slicing.
This touches _gather_batching_rule because slicing is implemented as a
gather, but we only test the case exercised by the slice that occurs
in our test transformer model, namely the unstack operation
  q, k, v = qkv
(which turns into three slices on an non-batched and non-ragged axis).

Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:33 -04:00
Alexey Radul
6f09fe840e Better error message when broadcasting ragged to static shape.
Co-authored-by: Matthew Johnson <mattjj@google.com>
2023-07-07 09:23:29 -04:00
Sharad Vikram
c446b42522 Add discharge rules for scan/while 2023-07-06 22:30:35 +00:00
Roy Frostig
ce9c2d650a rename seed_prng test method to make_key 2023-07-05 15:26:30 -07:00
Roy Frostig
ff70255af9 consistently seed keys indirectly by test class method in LaxRandomTest 2023-07-05 15:18:54 -07:00
Roy Frostig
556c1123cf parameterize two random tests over key constructors 2023-07-05 15:18:54 -07:00
Roy Frostig
c710c7578d move and remove code in random_test 2023-07-05 15:18:54 -07:00
jax authors
7c7051a4cc Merge pull request #16607 from froystig:random-test-double-threefry
PiperOrigin-RevId: 545799083
2023-07-05 15:15:35 -07:00
Roy Frostig
30542bd5bd match behavior of double-threefry test RNG and standard threefry RNG
This also lets us avoid a guard on `config.jax_enable_custom_prng` in
random tests.
2023-07-05 15:01:12 -07:00
Roy Frostig
09af6b1e01 test non-threefry RNGs across both typed and raw key formats
This also lets us remove some test guards on `config.jax_enable_custom_prng`.
2023-07-05 13:54:14 -07:00
Roy Frostig
bc44b99d05 avoid raw key arrays in typed key sharding test
This also lets us remove a guard on `config.jax_enable_custom_prng` in
random tests.
2023-06-30 20:38:26 -07:00
Roy Frostig
f8dee51d9a increase random test coverage over RNG key constructors and representations
This is an incremental change to our random tests that primarily:

* Increases test coverage of both key constructors (`random.key` and
  `random.PRNGKey`), often by parameterizing tests over both.

* Increases test coverage of both key representations (typed key
  arrays and `uint32` arrays).

* Removes a handful of guards on `config.jax_enable_custom_prng`,
  either replacing them with `isinstance` checks for typed keys or
  removing them altogether if possible.

* Makes a handful of other individual test improvements and fixes, and
  leaves comments for more.
2023-06-30 20:26:29 -07:00
Roy Frostig
9b346861a9 add impl option to random key constructors that picks the RNG implementation
This change primarily adds an optional argument to both old- and
new-style random key constructors. The option determines the PRNG
implementation for the key by name, overriding any default
implementation determined by configuration flags.

Along the way, looking ahead:

* We can deprecate the (anyway underused) individual explicit key
  constructors like `jax.random.threefr2x32_key` in favor of this
  option.

* Some day, instead of only accepting RNG implementations by name
  (string), we can also accept the output of some custom PRNG
  implementation API that we expose, maybe via `jax.extend.random`
  (corresponding roughly to the current `_src.prng.PRNGImpl`).
2023-06-30 16:42:22 -07:00
Jake VanderPlas
d0e75ca117 Require index update optional arguments to be passed by keyword.
Passing these keywords by position has been deprecated and has raised a warning since JAX v0.4.7 (Released 27 March 2023)

PiperOrigin-RevId: 544620172
2023-06-30 04:30:34 -07:00
Jake VanderPlas
18bbc96279 Fix integer overflow in gather batching rule 2023-06-27 21:45:45 -07:00
Roy Frostig
14f32653a1 resolve conditionals to default "shared operand form" more often
If both the second and third operand of a `lax.cond` call are callable, then
resolve it as a new-style (default) conditional, where both branches act on the
same operands.

This changes the behavior of five-argument `lax.cond` calls. It is a breaking
change for callers using the old-style `cond` calling convention (`pred`,
`true_arg`, `true_fn`, `false_arg`, `false_fn`) with a callable `true_arg`.

PiperOrigin-RevId: 543912445
2023-06-27 18:49:16 -07:00
Jake VanderPlas
30d1a8a80f Add jax.scipy.stats.binom 2023-06-27 03:41:38 -07:00
Yash Katariya
744a64fce6 Make sharding on ShapeDtypeStruct a property that always exists. The previous behavior was it only existed if sharding was not None.
sharding=None means that JAX is free to choose whatever sharding it wants. As it stands, jax will choose to mark the input as replicated but JAX reserves the right to change that as it sees fit.
PiperOrigin-RevId: 543630595
2023-06-26 21:46:50 -07:00
Parker Schuh
819f731e8d jax.lax.collapse now takes Nones for stop_dimension.
PiperOrigin-RevId: 543598626
2023-06-26 18:30:34 -07:00
Jieying Luo
21588a30a9 [PJRT C API] Add related C type definitions for key value get/put callback, as well as conversion between C and cpp types.
This is similar to how send/receive callback are implemented.

Update make_c_api_client to take key value get/put callback generated from distributed client, and optiosn of node_id and num_nodes.

PiperOrigin-RevId: 543441403
2023-06-26 08:08:52 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
jax authors
f67acee129 Merge pull request #16430 from jakevdp:bool-error
PiperOrigin-RevId: 542951181
2023-06-23 14:00:12 -07:00
jax authors
01a16f5914 Merge pull request #16487 from jakevdp:convolve-dtype
PiperOrigin-RevId: 542929304
2023-06-23 12:32:36 -07:00
Rahul Batra
2650c14cf5 [ROCm]: Re-enable EighIdentity test 2023-06-23 17:51:43 +00:00
Rahul Batra
c5bf05d66b [ROCm]: Re-enable SvdOnTinyElement test 2023-06-23 17:50:00 +00:00
Jake VanderPlas
ad35702934 Drop support for numpy 1.21
This is in accordance with NEP 29 and https://jax.readthedocs.io/en/latest/deprecation.html
2023-06-23 10:28:26 -07:00
jax authors
63415a9184 Merge pull request #16386 from axch:ragged-einsum
PiperOrigin-RevId: 542887557
2023-06-23 10:00:07 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
jax authors
7205e553f3 Merge pull request #16503 from jakevdp:sparsify-custom-jvp
PiperOrigin-RevId: 542790010
2023-06-23 01:26:41 -07:00
Jake VanderPlas
39645b5c20 Custom PRNG: improve test coverage when enable_custom_prng=false
We're now moving to a world where custom PRNG should exist side-by-side with the old PRNG
implementation. This change improves test coverage for that, by enabling relevant tests
even when the flag is set to False.
2023-06-23 00:29:11 -07:00
Jake VanderPlas
b6d544549b [sparse] support custom JVP in sparsify 2023-06-23 00:27:19 -07:00
Yash Katariya
fc0dcd15a2 Copybara import of the project:
--
57af5360a1ca1356dbf7760c76e241f7134ef6dd by Jake VanderPlas <jakevdp@google.com>:

[Roll forward] Update required Python version to 3.9

PiperOrigin-RevId: 542728213
2023-06-22 18:58:30 -07:00
Skye Wanderman-Milne
10424c5972 Update JAX's XlaExecutable.cost_analysis and related plumbing so it works on Cloud TPU
* Exposes LoadedExecutable.cost_analysis via pybind
* Updates XlaExecutable.cost_analysis to try
  LoadedExecutable.cost_analysis, then fallback to the client method.

PiperOrigin-RevId: 542671990
2023-06-22 14:43:00 -07:00