15578 Commits

Author SHA1 Message Date
jax authors
3c1f3abba2 Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
jax authors
efcd85daeb Merge pull request #15086 from shoyer:extrap
PiperOrigin-RevId: 521796315
2023-04-04 10:12:13 -07:00
jax authors
9bb3d86647 Merge pull request #15390 from jakevdp:checkify-dynamic-slice
PiperOrigin-RevId: 521790925
2023-04-04 09:54:19 -07:00
Jake VanderPlas
46297dccaf checkify: catch OOB errors in dynamic_slice
This will allow checkify tests to continue working properly after #15377
2023-04-04 08:16:59 -07:00
George Necula
35bfdc65e8 [shape_poly] Add some support for shape polymorphism for FFT, and tests
PiperOrigin-RevId: 521749241
2023-04-04 06:45:57 -07:00
Sharad Vikram
5101184ad4 Add initial implementation of a run_state primitive 2023-04-03 21:32:32 -07:00
jax authors
8a6c929678 Merge pull request #15289 from cgarciae:add-missing-api-references
PiperOrigin-RevId: 521617419
2023-04-03 18:23:04 -07:00
jax authors
b361f4cd0c Merge pull request #15169 from cgarciae:fix-lstm
PiperOrigin-RevId: 521616002
2023-04-03 18:13:19 -07:00
Yash Katariya
14b572f60d Remove _compile_replicated option from compile since it is not needed anymore and some other cosmetic fixes.
PiperOrigin-RevId: 521604489
2023-04-03 17:17:33 -07:00
Stephan Hoyer
4009005f0c Support extrapolation in jnp.interp
Fixes https://github.com/google/jax/issues/14858
2023-04-03 15:31:14 -07:00
Cristian Garcia
aa12e3597b handle seq_lengths in lstm_ref 2023-04-03 22:22:54 +00:00
Parker Schuh
c2b15a1eb8 Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Yash Katariya
78678ee9e1 Rename count_pjit_cache_miss with count_pjit_cpp_cache_miss because it is confusing which cache the first function is taking about as pjit has many caches
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
6f2256ad17 Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
George Necula
05249ec770 [jax2tf] Add more sharding tests with shape polymorphism
PiperOrigin-RevId: 521471546
2023-04-03 08:54:58 -07:00
George Necula
ff313a37a2 [jax2tf] Skip "graph" mode primitive tests on TPUs.
PiperOrigin-RevId: 521468145
2023-04-03 08:39:36 -07:00
jax authors
d743d23859 Convolution functions in TF, like- tf.nn.depthwise_conv2d_v2, tf.nn.conv2d_transpose_v2, tf.nn.conv2d_v2 all follow the same principal when it comes to padding(explained here- https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2). These principal happens to match with that of lax.convolution.conv_general_dilated.
So, this CL safely uses the padding(list[tuple(int, int)]) to call tf.nn.<conv> functions

PiperOrigin-RevId: 521464565
2023-04-03 08:22:02 -07:00
George Necula
607c7c1fdd Plumbing for dynamic shapes for custom calls.
PiperOrigin-RevId: 521439418
2023-04-03 06:16:12 -07:00
jax authors
0d32724882 Merge pull request #15340 from gnecula:dim_vars3
PiperOrigin-RevId: 521424534
2023-04-03 04:44:16 -07:00
George Necula
cd35e901aa [shape_poly] Cleanup handling of dimension variables.
We unify the way we compute with dimension variables (computing
their values from the shape of the actual arguments, and also
using those values to evaluate shapes that contain dimension variables).

We remove DimExprValueMlir, and all computations with dimension variables
and DimExpr are now done by JAX interpretation, followed by lowering to
TF or StableHLO.
2023-04-03 13:33:29 +02:00
George Necula
bf2c07121b [jax2tf] Add test that compile_args[tuple_args] does not matter for serialization
PiperOrigin-RevId: 521422653
2023-04-03 04:32:34 -07:00
George Necula
2ce78ac9a8 [jax2tf] Add checks that we do not see unexpected lowered.compiler_args
Some of those compile_args change the semantics and the calling convention
for the lowered module. We want to be explicit about the ones that we
are handling.

PiperOrigin-RevId: 521419681
2023-04-03 04:13:31 -07:00
jax authors
b0a6cdbf24 Merge pull request #15341 from gnecula:tf_grad
PiperOrigin-RevId: 521409693
2023-04-03 03:15:36 -07:00
Qiao Zhang
cf599c7d3e Avoid re-constructing set. Expensive at scale.
PiperOrigin-RevId: 521310375
2023-04-02 14:42:35 -07:00
George Necula
e9fc02eb14 [jax2tf] Cleanup of handling of tf.custom_gradient
There are some incompatibilities between JAX and TF when it comes
to gradients for functions that take non-float arguments, or whose
arguments are unused. JAX uses float0 for those gradients, but this
is not a type that TF recognizes. Furthermore, under tf.function
context TF will pass `None` in place of cotangents for the outputs
that have non-float type.

Previously, the workarounds for these were in the JAX function
that was converted to obtain the TF gradient. Here we move
those workarounds in the TF-land of jax2tf.

This will enable us to expand jax_export with handling of gradients.
jax_export is pure JAX, and hence it is important to move the
TF workarounds outside of the converte JAX functions.

This is just a refactor.
2023-04-02 21:07:44 +02:00
Eugene Burmako
b8dfb97e57 Integrate StableHLO at openxla/stablehlo@7a93924
PiperOrigin-RevId: 521293524
2023-04-02 11:14:01 -07:00
George Necula
88f77bbcc6 [jax2tf] Removed call_tf tests that are not applicable anymore.
A recent change in TensorFlow makes copies of np.ndarray when they
are turned into tf.constant. This means that call_tf cannot guarantee
anymore no-copy. Removing those tests, and the paragraph in the
documentation that describes this property.

PiperOrigin-RevId: 521120090
2023-04-01 03:07:13 -07:00
Yash Katariya
2432adefc3 Add Deprecation warning if gda_serialization is imported
PiperOrigin-RevId: 521081821
2023-03-31 21:28:07 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00
Yash Katariya
0b31e8b822 Remove dead code from pxla.py
PiperOrigin-RevId: 521003815
2023-03-31 13:51:49 -07:00
Ivy Zheng
db025df030 Stop importing old tree_util APIs conveniently and set explicit time for removal.
PiperOrigin-RevId: 521003611
2023-03-31 13:45:10 -07:00
Parker Schuh
82fcfc3851 Buffer -> Array in some pxla type annotations.
PiperOrigin-RevId: 520975371
2023-03-31 11:42:22 -07:00
Jake VanderPlas
b37c741c6f accelerate deprecation of jax.curry
PiperOrigin-RevId: 520958381
2023-03-31 10:37:39 -07:00
jax authors
ffb8352848 Merge pull request #15342 from jakevdp:doc-requirements
PiperOrigin-RevId: 520955387
2023-03-31 10:27:30 -07:00
jax authors
2841bd310e Merge pull request #15321 from jakevdp:remove-msort
PiperOrigin-RevId: 520952178
2023-03-31 10:16:18 -07:00
Zafarali Ahmed
6e00ba8bad Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.

PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
jax authors
dfbbc2551c Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix
PiperOrigin-RevId: 520934992
2023-03-31 09:05:07 -07:00
Peter Hawkins
abf1acf76c Replace references to jax.interpreters with jax._src.interpreters in JAX core.
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
jax authors
182cc9857c Merge pull request #15323 from NeilGirdhar:fix_rayleigh
PiperOrigin-RevId: 520932851
2023-03-31 08:50:40 -07:00
Jake VanderPlas
9ec3ad1ce7 DOC: pin newest sphinx-book-theme 2023-03-31 08:42:34 -07:00
Jake VanderPlas
749dc1b95e Remove deprecated function jnp.msort 2023-03-31 08:24:36 -07:00
jax authors
0df2ddcf0e Merge pull request #15232 from gnecula:tf_arange
PiperOrigin-RevId: 520914838
2023-03-31 07:11:19 -07:00
George Necula
c368c69625 [shape_poly] Extend the handling of jnp.arange with shape polymorphism.
Previously, only `arange(stop, dtype=...)` was being handled in presence
of shape polymorphism. Here we extend to add support for `start` and `step`
to be also present. There are still plenty of restrictions:

   * no floating point constants are allowed among start, stop and step
   * we must resolve statically if step is positive or negative
   * we must resolve statically if the distance between start and stop
     is negative or positive.
2023-03-31 14:41:26 +02:00
jax authors
76b922aade Merge pull request #15337 from mattjj:axis-name-shadowing-2
PiperOrigin-RevId: 520838748
2023-03-30 23:01:02 -07:00
Matthew Johnson
6a2b081506 fix bug from #15335 by checking main_trace tag 2023-03-30 22:35:03 -07:00
jax authors
12bcdeb69e Merge pull request #15335 from mattjj:axis-name-shadowing
PiperOrigin-RevId: 520829991
2023-03-30 21:56:42 -07:00
Matthew Johnson
211bc29842 add assertions for axis name shadowing bugs 2023-03-30 21:31:02 -07:00
jax authors
d383ab65dc Merge pull request #15255 from eltociear:patch-6
PiperOrigin-RevId: 520814903
2023-03-30 20:38:17 -07:00
jax authors
8e17da477c Merge pull request #15322 from jakevdp:pre-commit
PiperOrigin-RevId: 520793950
2023-03-30 18:26:27 -07:00
jax authors
248ffc2ca2 Merge pull request #15329 from jakevdp:padfunc-protocol-2
PiperOrigin-RevId: 520793934
2023-03-30 18:19:43 -07:00