15334 Commits

Author SHA1 Message Date
George Necula
961e09e614 [shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.

The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.

The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.

This change should be enough for old-style jax2tf, but will require more
work for native serialization.

We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.

PiperOrigin-RevId: 515137407
2023-03-08 14:10:08 -08:00
Parker Schuh
942e79ffe3 Fix tests that were adding __annotations__ for some reason to CompileOptions.
PiperOrigin-RevId: 515133225
2023-03-08 13:55:03 -08:00
jax authors
9c4db8c962 Merge pull request #14633 from mattjj:shmap-test-vmap
PiperOrigin-RevId: 515117185
2023-03-08 12:56:54 -08:00
jax authors
6634600c46 Merge pull request #14840 from mattjj:relu6-grad-at-0-and-6
PiperOrigin-RevId: 515106619
2023-03-08 12:15:38 -08:00
Matthew Johnson
704b8735df [shard-map] add systematic vmap-of-shmap tests 2023-03-08 12:08:37 -08:00
Jake VanderPlas
44082be103 Set ArrayImpl.__name__ to ArrayImpl
Fixes https://github.com/google/jax/issues/14768

PiperOrigin-RevId: 515097907
2023-03-08 11:43:29 -08:00
jax authors
2c2dfbe89b Merge pull request #14841 from hawkinsp:pmap
PiperOrigin-RevId: 515089941
2023-03-08 11:21:51 -08:00
jax authors
0f2d48b078 Merge pull request #14838 from jakevdp:numpy-indexing-imports
PiperOrigin-RevId: 515089511
2023-03-08 11:15:24 -08:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
jax authors
5daa4f3dc6 Merge pull request #14855 from jakevdp:bcoo-integer-pow
PiperOrigin-RevId: 515069303
2023-03-08 10:10:08 -08:00
jax authors
794d8b7128 Expose ArgInfo
All the classes exposed here have a public `arg_infos` attribute, whose type is 'PyTree of ArgInfo'. Any code that wants to construct, implement or manipulate the types exposed in this file will likely need access to the type definition of their public member variables.

PiperOrigin-RevId: 515067491
2023-03-08 10:03:17 -08:00
Jake VanderPlas
f32e72da2a [sparse] add support for integer_pow 2023-03-08 09:24:52 -08:00
jax authors
d4d73224fc Merge pull request #14829 from jakevdp:spdot-general-nse
PiperOrigin-RevId: 515054108
2023-03-08 09:08:46 -08:00
Peter Hawkins
6c2e240634 Add argnames and resultnames to pmap. 2023-03-08 10:13:30 -05:00
Jake VanderPlas
4180f8bf7b [sparse] improve worst-case nse in spdot_general 2023-03-07 20:46:27 -08:00
Parker Schuh
d62fc88fb1 Roll back #14792
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).

PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Matthew Johnson
9c39b6f70c update relu6 grad at 0 and 6 to match pytorch convention 2023-03-07 17:30:17 -08:00
jax authors
a51caababf Merge pull request #14823 from jakevdp:rand-bcoo-indtype
PiperOrigin-RevId: 514877270
2023-03-07 16:49:52 -08:00
jax authors
cc694c66ce Merge pull request #14798 from nicholasjng:custom-linear-solve-batching-fix
PiperOrigin-RevId: 514873672
2023-03-07 16:39:18 -08:00
jax authors
75d8fb0b1d Merge pull request #14836 from shoyer:searchsorted
PiperOrigin-RevId: 514873323
2023-03-07 16:32:55 -08:00
jax authors
103822ba87 Merge pull request #14835 from patrick-kidger:patch-1
PiperOrigin-RevId: 514870092
2023-03-07 16:20:49 -08:00
Patrick Kidger
39fa1cadce
ravel_pytree now accepted by static linters 2023-03-07 23:35:48 +00:00
Stephan Hoyer
d4f70c8071 Add "compare_all" method to searchsorted 2023-03-07 16:34:24 -07:00
jax authors
b4ec72deae Merge pull request #14792 from b0nce:fix-scipy-stats
PiperOrigin-RevId: 514790201
2023-03-07 11:24:37 -08:00
George Necula
44e6d3cd6b Internal change
PiperOrigin-RevId: 514785311
2023-03-07 11:11:44 -08:00
jax authors
7b4863d210 Merge pull request #14782 from Gattocrucco:vectorize-no-broadcast
PiperOrigin-RevId: 514784797
2023-03-07 11:05:19 -08:00
Parker Schuh
61e589bd20 Convert testShardArgs to handle pxla.Chunked sharding properly.
Chunked + Unstacked shardings are invalid, so delete or update those
tests.

PiperOrigin-RevId: 514767811
2023-03-07 10:10:47 -08:00
Sholto Douglas
96219128ba Fixes "Unhashable type set" bug
PiperOrigin-RevId: 514754424
2023-03-07 09:23:44 -08:00
pizzud
22cbf95e07 lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.

PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Jake VanderPlas
e46062d72c [sparse] more careful handling of index_dtype in rand_bcoo 2023-03-07 08:34:10 -08:00
Adam Roberts
ea68198f37 Correctly cast crc32 to unit32 instead of int32 to avoid the following warnings:
PiperOrigin-RevId: 514736060
2023-03-07 08:11:12 -08:00
jax authors
35f86804c1 Merge pull request #14811 from jakevdp:fix-cusparse-tests
PiperOrigin-RevId: 514725438
2023-03-07 07:22:24 -08:00
jax authors
88d5a4110b Merge pull request #14749 from gnecula:tf_bincount_poly
PiperOrigin-RevId: 514712599
2023-03-07 06:16:32 -08:00
George Necula
595836e69c [shape_poly] Disable native lowering test for cumsum on GPU
The test fails because of recent changes in the lowering rule for
associative reductions on GPU.

PiperOrigin-RevId: 514694932
2023-03-07 04:42:19 -08:00
Giacomo Petrillo
95a5b4e48a avoid unnecessary broadcasting in jax.numpy.vectorize 2023-03-07 11:56:34 +01:00
Nicholas Junge
27b26515fe Add regression test 2023-03-07 09:30:03 +01:00
George Necula
afe4f8ed1a [shape_poly] Add support for shape polymorphism for jnp.{argsort,bincount,insert,nonzero} 2023-03-07 08:29:07 +01:00
Misha
feb9ab33af Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic. 2023-03-07 07:56:47 +01:00
jax authors
da3b75aacc Merge pull request #14813 from mattjj:14780
PiperOrigin-RevId: 514627923
2023-03-06 22:17:55 -08:00
Matthew Johnson
b05975b964 add result info to mhlo, fixes #14780
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
Jake VanderPlas
b527bcaa3c [sparse] fix GPU warnings in cusparse test 2023-03-06 17:39:28 -08:00
jax authors
0ec82f4d62 Merge pull request #14809 from jakevdp:fix-oob-correction
PiperOrigin-RevId: 514572448
2023-03-06 17:25:16 -08:00
Jake VanderPlas
6d750def36 [sparse] fix OOB index correction 2023-03-06 17:00:15 -08:00
Parker Schuh
dd5e04ca88 Ensure that the sharding specs are always set to match the number of
addressable devices when pmap_nreps is set.

PiperOrigin-RevId: 514537352
2023-03-06 15:02:25 -08:00
jax authors
00f1abe401 Disable 2 failing jax tests.
PiperOrigin-RevId: 514515343
2023-03-06 13:50:40 -08:00
jax authors
7d154103e3 Merge pull request #14789 from patrick-kidger:patch-2
PiperOrigin-RevId: 514490253
2023-03-06 12:27:53 -08:00
Colin Gaffney
b4527f2435 Modify get_tensorstore_spec to support ocdbt driver option.
PiperOrigin-RevId: 514485035
2023-03-06 12:10:49 -08:00
Tianjian Lu
7bcd490b69 [sparse] add low-level primitives wrapping cuda csr spmv and spmm.
PiperOrigin-RevId: 514473374
2023-03-06 11:34:30 -08:00
jax authors
46e4c6bf35 Merge pull request #14695 from thisiscam:cli_debugger_readline
PiperOrigin-RevId: 514462767
2023-03-06 11:01:59 -08:00
Patrick Kidger
17afaf67b9 Add _ScalarMeta(dtype=...) field for static type checkers 2023-03-06 10:49:30 -08:00