George Necula
961e09e614
[shape_poly, call_tf] Some improvements for call_tf in a shape polymorphic program
...
This is another attempt to land a rolled-back change https://github.com/google/jax/pull/14734 (cl/514070997).
See b/272154366 for more details.
The use case for call_tf with shape polymorphism is when we have a JAX program
that calls into TF function, and we want to serialize the JAX program with
some shapes unknown. Previously this use case did not work, except in the special
case when the output shape of the called TF function returns statically known
shapes.
The idea is that we allow the user of call_tf to specify the output shape.
This can be done even in presence of shape polymorphism, by writing the
output shape as an expression in terms of the input shapes. This is what
other JAX primitives do, e.g., concat, so we are simply enabling call_tf
to get the same behavior.
This change should be enough for old-style jax2tf, but will require more
work for native serialization.
We also removed some old code that was trying to workaround some limitations
in shape inference in TF. I think that those workarounds are ugly, and I am
prepared to give error messages rather than keep that code. So far no
tests fail.
PiperOrigin-RevId: 515137407
2023-03-08 14:10:08 -08:00
Parker Schuh
942e79ffe3
Fix tests that were adding __annotations__ for some reason to CompileOptions.
...
PiperOrigin-RevId: 515133225
2023-03-08 13:55:03 -08:00
jax authors
9c4db8c962
Merge pull request #14633 from mattjj:shmap-test-vmap
...
PiperOrigin-RevId: 515117185
2023-03-08 12:56:54 -08:00
jax authors
6634600c46
Merge pull request #14840 from mattjj:relu6-grad-at-0-and-6
...
PiperOrigin-RevId: 515106619
2023-03-08 12:15:38 -08:00
Matthew Johnson
704b8735df
[shard-map] add systematic vmap-of-shmap tests
2023-03-08 12:08:37 -08:00
Jake VanderPlas
44082be103
Set ArrayImpl.__name__ to ArrayImpl
...
Fixes https://github.com/google/jax/issues/14768
PiperOrigin-RevId: 515097907
2023-03-08 11:43:29 -08:00
jax authors
2c2dfbe89b
Merge pull request #14841 from hawkinsp:pmap
...
PiperOrigin-RevId: 515089941
2023-03-08 11:21:51 -08:00
jax authors
0f2d48b078
Merge pull request #14838 from jakevdp:numpy-indexing-imports
...
PiperOrigin-RevId: 515089511
2023-03-08 11:15:24 -08:00
Jake VanderPlas
c8c269f5f5
internal: avoid unused imports in lax_numpy
2023-03-08 10:29:04 -08:00
jax authors
5daa4f3dc6
Merge pull request #14855 from jakevdp:bcoo-integer-pow
...
PiperOrigin-RevId: 515069303
2023-03-08 10:10:08 -08:00
jax authors
794d8b7128
Expose ArgInfo
...
All the classes exposed here have a public `arg_infos` attribute, whose type is 'PyTree of ArgInfo'. Any code that wants to construct, implement or manipulate the types exposed in this file will likely need access to the type definition of their public member variables.
PiperOrigin-RevId: 515067491
2023-03-08 10:03:17 -08:00
Jake VanderPlas
f32e72da2a
[sparse] add support for integer_pow
2023-03-08 09:24:52 -08:00
jax authors
d4d73224fc
Merge pull request #14829 from jakevdp:spdot-general-nse
...
PiperOrigin-RevId: 515054108
2023-03-08 09:08:46 -08:00
Peter Hawkins
6c2e240634
Add argnames and resultnames to pmap.
2023-03-08 10:13:30 -05:00
Jake VanderPlas
4180f8bf7b
[sparse] improve worst-case nse in spdot_general
2023-03-07 20:46:27 -08:00
Parker Schuh
d62fc88fb1
Roll back #14792
...
Breaks tests. lax.sub requires arguments to have the same dtypes, got float32, float64. (Tip: jnp.subtract is a similar function that does automatic type promotion on inputs).
PiperOrigin-RevId: 514897538
2023-03-07 18:31:19 -08:00
Matthew Johnson
9c39b6f70c
update relu6 grad at 0 and 6 to match pytorch convention
2023-03-07 17:30:17 -08:00
jax authors
a51caababf
Merge pull request #14823 from jakevdp:rand-bcoo-indtype
...
PiperOrigin-RevId: 514877270
2023-03-07 16:49:52 -08:00
jax authors
cc694c66ce
Merge pull request #14798 from nicholasjng:custom-linear-solve-batching-fix
...
PiperOrigin-RevId: 514873672
2023-03-07 16:39:18 -08:00
jax authors
75d8fb0b1d
Merge pull request #14836 from shoyer:searchsorted
...
PiperOrigin-RevId: 514873323
2023-03-07 16:32:55 -08:00
jax authors
103822ba87
Merge pull request #14835 from patrick-kidger:patch-1
...
PiperOrigin-RevId: 514870092
2023-03-07 16:20:49 -08:00
Patrick Kidger
39fa1cadce
ravel_pytree
now accepted by static linters
2023-03-07 23:35:48 +00:00
Stephan Hoyer
d4f70c8071
Add "compare_all" method to searchsorted
2023-03-07 16:34:24 -07:00
jax authors
b4ec72deae
Merge pull request #14792 from b0nce:fix-scipy-stats
...
PiperOrigin-RevId: 514790201
2023-03-07 11:24:37 -08:00
George Necula
44e6d3cd6b
Internal change
...
PiperOrigin-RevId: 514785311
2023-03-07 11:11:44 -08:00
jax authors
7b4863d210
Merge pull request #14782 from Gattocrucco:vectorize-no-broadcast
...
PiperOrigin-RevId: 514784797
2023-03-07 11:05:19 -08:00
Parker Schuh
61e589bd20
Convert testShardArgs to handle pxla.Chunked sharding properly.
...
Chunked + Unstacked shardings are invalid, so delete or update those
tests.
PiperOrigin-RevId: 514767811
2023-03-07 10:10:47 -08:00
Sholto Douglas
96219128ba
Fixes "Unhashable type set" bug
...
PiperOrigin-RevId: 514754424
2023-03-07 09:23:44 -08:00
pizzud
22cbf95e07
lax_vmap_test: Extend timeout so that the TPU variant can run in ASAN.
...
Unfortunately we can't conditionally change the timeout, as size and timeout
are both non-configurable even if jax_test supported setting the size.
PiperOrigin-RevId: 514745247
2023-03-07 08:49:42 -08:00
Jake VanderPlas
e46062d72c
[sparse] more careful handling of index_dtype in rand_bcoo
2023-03-07 08:34:10 -08:00
Adam Roberts
ea68198f37
Correctly cast crc32 to unit32 instead of int32 to avoid the following warnings:
...
PiperOrigin-RevId: 514736060
2023-03-07 08:11:12 -08:00
jax authors
35f86804c1
Merge pull request #14811 from jakevdp:fix-cusparse-tests
...
PiperOrigin-RevId: 514725438
2023-03-07 07:22:24 -08:00
jax authors
88d5a4110b
Merge pull request #14749 from gnecula:tf_bincount_poly
...
PiperOrigin-RevId: 514712599
2023-03-07 06:16:32 -08:00
George Necula
595836e69c
[shape_poly] Disable native lowering test for cumsum on GPU
...
The test fails because of recent changes in the lowering rule for
associative reductions on GPU.
PiperOrigin-RevId: 514694932
2023-03-07 04:42:19 -08:00
Giacomo Petrillo
95a5b4e48a
avoid unnecessary broadcasting in jax.numpy.vectorize
2023-03-07 11:56:34 +01:00
Nicholas Junge
27b26515fe
Add regression test
2023-03-07 09:30:03 +01:00
George Necula
afe4f8ed1a
[shape_poly] Add support for shape polymorphism for jnp.{argsort,bincount,insert,nonzero}
2023-03-07 08:29:07 +01:00
Misha
feb9ab33af
Fixed loc and scale parameters for logistic distribution. CDF and SF have been added for several distributions, including cauchy, gamma, logistic, chi2 and beta. ISF and PPF have also been added for cauchy and logistic.
2023-03-07 07:56:47 +01:00
jax authors
da3b75aacc
Merge pull request #14813 from mattjj:14780
...
PiperOrigin-RevId: 514627923
2023-03-06 22:17:55 -08:00
Matthew Johnson
b05975b964
add result info to mhlo, fixes #14780
...
incidentally fixes #14787
2023-03-06 21:21:26 -08:00
Jake VanderPlas
b527bcaa3c
[sparse] fix GPU warnings in cusparse test
2023-03-06 17:39:28 -08:00
jax authors
0ec82f4d62
Merge pull request #14809 from jakevdp:fix-oob-correction
...
PiperOrigin-RevId: 514572448
2023-03-06 17:25:16 -08:00
Jake VanderPlas
6d750def36
[sparse] fix OOB index correction
2023-03-06 17:00:15 -08:00
Parker Schuh
dd5e04ca88
Ensure that the sharding specs are always set to match the number of
...
addressable devices when pmap_nreps is set.
PiperOrigin-RevId: 514537352
2023-03-06 15:02:25 -08:00
jax authors
00f1abe401
Disable 2 failing jax tests.
...
PiperOrigin-RevId: 514515343
2023-03-06 13:50:40 -08:00
jax authors
7d154103e3
Merge pull request #14789 from patrick-kidger:patch-2
...
PiperOrigin-RevId: 514490253
2023-03-06 12:27:53 -08:00
Colin Gaffney
b4527f2435
Modify get_tensorstore_spec to support ocdbt driver option.
...
PiperOrigin-RevId: 514485035
2023-03-06 12:10:49 -08:00
Tianjian Lu
7bcd490b69
[sparse] add low-level primitives wrapping cuda csr spmv and spmm.
...
PiperOrigin-RevId: 514473374
2023-03-06 11:34:30 -08:00
jax authors
46e4c6bf35
Merge pull request #14695 from thisiscam:cli_debugger_readline
...
PiperOrigin-RevId: 514462767
2023-03-06 11:01:59 -08:00
Patrick Kidger
17afaf67b9
Add _ScalarMeta(dtype=...)
field for static type checkers
2023-03-06 10:49:30 -08:00