jax authors
3c1f3abba2
Merge pull request #15149 from sharadmv:runstate
...
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
jax authors
efcd85daeb
Merge pull request #15086 from shoyer:extrap
...
PiperOrigin-RevId: 521796315
2023-04-04 10:12:13 -07:00
jax authors
9bb3d86647
Merge pull request #15390 from jakevdp:checkify-dynamic-slice
...
PiperOrigin-RevId: 521790925
2023-04-04 09:54:19 -07:00
Jake VanderPlas
46297dccaf
checkify: catch OOB errors in dynamic_slice
...
This will allow checkify tests to continue working properly after #15377
2023-04-04 08:16:59 -07:00
George Necula
35bfdc65e8
[shape_poly] Add some support for shape polymorphism for FFT, and tests
...
PiperOrigin-RevId: 521749241
2023-04-04 06:45:57 -07:00
Sharad Vikram
5101184ad4
Add initial implementation of a run_state
primitive
2023-04-03 21:32:32 -07:00
jax authors
8a6c929678
Merge pull request #15289 from cgarciae:add-missing-api-references
...
PiperOrigin-RevId: 521617419
2023-04-03 18:23:04 -07:00
jax authors
b361f4cd0c
Merge pull request #15169 from cgarciae:fix-lstm
...
PiperOrigin-RevId: 521616002
2023-04-03 18:13:19 -07:00
Yash Katariya
14b572f60d
Remove _compile_replicated option from compile
since it is not needed anymore and some other cosmetic fixes.
...
PiperOrigin-RevId: 521604489
2023-04-03 17:17:33 -07:00
Stephan Hoyer
4009005f0c
Support extrapolation in jnp.interp
...
Fixes https://github.com/google/jax/issues/14858
2023-04-03 15:31:14 -07:00
Cristian Garcia
aa12e3597b
handle seq_lengths in lstm_ref
2023-04-03 22:22:54 +00:00
Parker Schuh
c2b15a1eb8
Break out aot_test from array_test (for serialization and other aot APIs).
...
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Yash Katariya
78678ee9e1
Rename count_pjit_cache_miss
with count_pjit_cpp_cache_miss
because it is confusing which cache the first function is taking about as pjit has many caches
...
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
6f2256ad17
Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
...
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
George Necula
05249ec770
[jax2tf] Add more sharding tests with shape polymorphism
...
PiperOrigin-RevId: 521471546
2023-04-03 08:54:58 -07:00
George Necula
ff313a37a2
[jax2tf] Skip "graph" mode primitive tests on TPUs.
...
PiperOrigin-RevId: 521468145
2023-04-03 08:39:36 -07:00
jax authors
d743d23859
Convolution functions in TF, like- tf.nn.depthwise_conv2d_v2, tf.nn.conv2d_transpose_v2, tf.nn.conv2d_v2 all follow the same principal when it comes to padding(explained here- https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 ). These principal happens to match with that of lax.convolution.conv_general_dilated
.
...
So, this CL safely uses the padding(list[tuple(int, int)]) to call tf.nn.<conv> functions
PiperOrigin-RevId: 521464565
2023-04-03 08:22:02 -07:00
George Necula
607c7c1fdd
Plumbing for dynamic shapes for custom calls.
...
PiperOrigin-RevId: 521439418
2023-04-03 06:16:12 -07:00
jax authors
0d32724882
Merge pull request #15340 from gnecula:dim_vars3
...
PiperOrigin-RevId: 521424534
2023-04-03 04:44:16 -07:00
George Necula
cd35e901aa
[shape_poly] Cleanup handling of dimension variables.
...
We unify the way we compute with dimension variables (computing
their values from the shape of the actual arguments, and also
using those values to evaluate shapes that contain dimension variables).
We remove DimExprValueMlir, and all computations with dimension variables
and DimExpr are now done by JAX interpretation, followed by lowering to
TF or StableHLO.
2023-04-03 13:33:29 +02:00
George Necula
bf2c07121b
[jax2tf] Add test that compile_args[tuple_args] does not matter for serialization
...
PiperOrigin-RevId: 521422653
2023-04-03 04:32:34 -07:00
George Necula
2ce78ac9a8
[jax2tf] Add checks that we do not see unexpected lowered.compiler_args
...
Some of those compile_args change the semantics and the calling convention
for the lowered module. We want to be explicit about the ones that we
are handling.
PiperOrigin-RevId: 521419681
2023-04-03 04:13:31 -07:00
jax authors
b0a6cdbf24
Merge pull request #15341 from gnecula:tf_grad
...
PiperOrigin-RevId: 521409693
2023-04-03 03:15:36 -07:00
Qiao Zhang
cf599c7d3e
Avoid re-constructing set. Expensive at scale.
...
PiperOrigin-RevId: 521310375
2023-04-02 14:42:35 -07:00
George Necula
e9fc02eb14
[jax2tf] Cleanup of handling of tf.custom_gradient
...
There are some incompatibilities between JAX and TF when it comes
to gradients for functions that take non-float arguments, or whose
arguments are unused. JAX uses float0 for those gradients, but this
is not a type that TF recognizes. Furthermore, under tf.function
context TF will pass `None` in place of cotangents for the outputs
that have non-float type.
Previously, the workarounds for these were in the JAX function
that was converted to obtain the TF gradient. Here we move
those workarounds in the TF-land of jax2tf.
This will enable us to expand jax_export with handling of gradients.
jax_export is pure JAX, and hence it is important to move the
TF workarounds outside of the converte JAX functions.
This is just a refactor.
2023-04-02 21:07:44 +02:00
Eugene Burmako
b8dfb97e57
Integrate StableHLO at openxla/stablehlo@7a93924
...
PiperOrigin-RevId: 521293524
2023-04-02 11:14:01 -07:00
George Necula
88f77bbcc6
[jax2tf] Removed call_tf tests that are not applicable anymore.
...
A recent change in TensorFlow makes copies of np.ndarray when they
are turned into tf.constant. This means that call_tf cannot guarantee
anymore no-copy. Removing those tests, and the paragraph in the
documentation that describes this property.
PiperOrigin-RevId: 521120090
2023-04-01 03:07:13 -07:00
Yash Katariya
2432adefc3
Add Deprecation warning if gda_serialization is imported
...
PiperOrigin-RevId: 521081821
2023-03-31 21:28:07 -07:00
Yash Katariya
d27a80dbfa
Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
...
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00
Yash Katariya
0b31e8b822
Remove dead code from pxla.py
...
PiperOrigin-RevId: 521003815
2023-03-31 13:51:49 -07:00
Ivy Zheng
db025df030
Stop importing old tree_util
APIs conveniently and set explicit time for removal.
...
PiperOrigin-RevId: 521003611
2023-03-31 13:45:10 -07:00
Parker Schuh
82fcfc3851
Buffer -> Array in some pxla type annotations.
...
PiperOrigin-RevId: 520975371
2023-03-31 11:42:22 -07:00
Jake VanderPlas
b37c741c6f
accelerate deprecation of jax.curry
...
PiperOrigin-RevId: 520958381
2023-03-31 10:37:39 -07:00
jax authors
ffb8352848
Merge pull request #15342 from jakevdp:doc-requirements
...
PiperOrigin-RevId: 520955387
2023-03-31 10:27:30 -07:00
jax authors
2841bd310e
Merge pull request #15321 from jakevdp:remove-msort
...
PiperOrigin-RevId: 520952178
2023-03-31 10:16:18 -07:00
Zafarali Ahmed
6e00ba8bad
Enable more mesh shape assignment
...
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.
PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
jax authors
dfbbc2551c
Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix
...
PiperOrigin-RevId: 520934992
2023-03-31 09:05:07 -07:00
Peter Hawkins
abf1acf76c
Replace references to jax.interpreters with jax._src.interpreters in JAX core.
...
PiperOrigin-RevId: 520933067
2023-03-31 08:58:00 -07:00
jax authors
182cc9857c
Merge pull request #15323 from NeilGirdhar:fix_rayleigh
...
PiperOrigin-RevId: 520932851
2023-03-31 08:50:40 -07:00
Jake VanderPlas
9ec3ad1ce7
DOC: pin newest sphinx-book-theme
2023-03-31 08:42:34 -07:00
Jake VanderPlas
749dc1b95e
Remove deprecated function jnp.msort
2023-03-31 08:24:36 -07:00
jax authors
0df2ddcf0e
Merge pull request #15232 from gnecula:tf_arange
...
PiperOrigin-RevId: 520914838
2023-03-31 07:11:19 -07:00
George Necula
c368c69625
[shape_poly] Extend the handling of jnp.arange with shape polymorphism.
...
Previously, only `arange(stop, dtype=...)` was being handled in presence
of shape polymorphism. Here we extend to add support for `start` and `step`
to be also present. There are still plenty of restrictions:
* no floating point constants are allowed among start, stop and step
* we must resolve statically if step is positive or negative
* we must resolve statically if the distance between start and stop
is negative or positive.
2023-03-31 14:41:26 +02:00
jax authors
76b922aade
Merge pull request #15337 from mattjj:axis-name-shadowing-2
...
PiperOrigin-RevId: 520838748
2023-03-30 23:01:02 -07:00
Matthew Johnson
6a2b081506
fix bug from #15335 by checking main_trace tag
2023-03-30 22:35:03 -07:00
jax authors
12bcdeb69e
Merge pull request #15335 from mattjj:axis-name-shadowing
...
PiperOrigin-RevId: 520829991
2023-03-30 21:56:42 -07:00
Matthew Johnson
211bc29842
add assertions for axis name shadowing bugs
2023-03-30 21:31:02 -07:00
jax authors
d383ab65dc
Merge pull request #15255 from eltociear:patch-6
...
PiperOrigin-RevId: 520814903
2023-03-30 20:38:17 -07:00
jax authors
8e17da477c
Merge pull request #15322 from jakevdp:pre-commit
...
PiperOrigin-RevId: 520793950
2023-03-30 18:26:27 -07:00
jax authors
248ffc2ca2
Merge pull request #15329 from jakevdp:padfunc-protocol-2
...
PiperOrigin-RevId: 520793934
2023-03-30 18:19:43 -07:00