15605 Commits

Author SHA1 Message Date
Peter Hawkins
dfe95dcb4e Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Yash Katariya
b8ade584bf Add more multi device array slicing tests
PiperOrigin-RevId: 522345812
2023-04-06 08:45:36 -07:00
Peter Hawkins
452f3c55e3 Rename jax._src.sharding_utils to jax._src.op_shardings.
Move some more op_sharding related helpers to that module.

PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
jax authors
492b9c1455 Merge pull request #15397 from jakevdp:fix-split-annotation
PiperOrigin-RevId: 522341314
2023-04-06 08:23:54 -07:00
Yash Katariya
b926e04afc Remove the shim of functions in sharding_utils from pxla.py and use those functions directly from sharding_utils in JAX
PiperOrigin-RevId: 522319332
2023-04-06 06:18:03 -07:00
jax authors
95525e7f8d Merge pull request #15399 from mattjj:issue15398
PiperOrigin-RevId: 522255300
2023-04-05 23:23:27 -07:00
Yash Katariya
728a5ed96a [shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-04-05 23:04:41 -07:00
jax authors
8ef3b04deb Merge pull request #15428 from mattjj:jax-array
PiperOrigin-RevId: 522232242
2023-04-05 21:00:31 -07:00
Matthew Johnson
f9259f3f62 fix a __jax_array__ bug 2023-04-05 20:24:53 -07:00
Yash Katariya
03d5aaad96 Switch the implementation of sharded_aval to a simpler one.
Create sharding_utils.py to move utilities from pxla.py to sharding_utils.py to break cyclic deps.

PiperOrigin-RevId: 522209346
2023-04-05 18:32:00 -07:00
Jake VanderPlas
c10cb17751 Accelerate deprecation of jax.ShapedArray
This is deprecated as of https://github.com/google/jax/pull/15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.

PiperOrigin-RevId: 522168275
2023-04-05 15:26:19 -07:00
Jake VanderPlas
8c8f50f688 Fix tolerance and shard_count for experimental_rnn_test
This should fix the current GPU test timeout.

PiperOrigin-RevId: 522167894
2023-04-05 15:19:19 -07:00
jax authors
f6da71c807 Merge pull request #15401 from mattjj:issue15400
PiperOrigin-RevId: 522115286
2023-04-05 11:55:47 -07:00
Peter Hawkins
29ba2ca926 Report the argument path when encountering an overflow error for a Python value.
PiperOrigin-RevId: 522106244
2023-04-05 11:24:40 -07:00
jax authors
7b4e579202 Merge pull request #15415 from jakevdp:sparse-add
PiperOrigin-RevId: 522101591
2023-04-05 11:07:50 -07:00
Jake VanderPlas
05f32a7947 [sparse] allow sparse-dense add when the output is the same size as dense input 2023-04-05 10:39:43 -07:00
John QiangZhang
0e549ac4be Update unit test and doc how work around jit_compile=False on TPU for native_serialization.
PiperOrigin-RevId: 522077249
2023-04-05 09:44:34 -07:00
Matthew Johnson
ac4942d7f7 fix conj transpose on symbolic zero
fixes #15400
2023-04-04 20:45:21 -07:00
Peter Hawkins
bf50551e0f Explicitly import jax.custom_{batching,derivatives,transpose}.
https://github.com/google/jax/pull/15391 had the unintentional side effect of causing these names not to be imported by default. Restore the status quo by importing them.

PiperOrigin-RevId: 521898088
2023-04-04 16:40:15 -07:00
jax authors
aab24fead0 Merge pull request #15396 from jakevdp:conv-elem-type
PiperOrigin-RevId: 521895290
2023-04-04 16:26:30 -07:00
Jake VanderPlas
aa643a895c [typing] fix annotation of jnp.split 2023-04-04 16:17:21 -07:00
Jake VanderPlas
c2fe350455 future-proof lax.convert_element_type
In the future, np.array(large_value, 'int32') will error
2023-04-04 15:57:32 -07:00
Yash Katariya
ffa9d018d6 DCE as early as possible so that committed is not dependent on DCE's vars
PiperOrigin-RevId: 521879918
2023-04-04 15:21:12 -07:00
Parker Schuh
9095faaeb0 Remove PyBuffer type and its bindings.
PiperOrigin-RevId: 521865179
2023-04-04 14:24:23 -07:00
Parker Schuh
6040580fa3 stages should not eagerly load the executables by calling cpp_call.
PiperOrigin-RevId: 521849296
2023-04-04 13:25:24 -07:00
Peter Hawkins
75d0f6522d Add cupti pip dependency, needed for GPU profiling.
Issue https://github.com/google/jax/issues/15384

PiperOrigin-RevId: 521841461
2023-04-04 12:55:36 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
jax authors
3c1f3abba2 Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
jax authors
efcd85daeb Merge pull request #15086 from shoyer:extrap
PiperOrigin-RevId: 521796315
2023-04-04 10:12:13 -07:00
jax authors
9bb3d86647 Merge pull request #15390 from jakevdp:checkify-dynamic-slice
PiperOrigin-RevId: 521790925
2023-04-04 09:54:19 -07:00
Jake VanderPlas
46297dccaf checkify: catch OOB errors in dynamic_slice
This will allow checkify tests to continue working properly after #15377
2023-04-04 08:16:59 -07:00
George Necula
35bfdc65e8 [shape_poly] Add some support for shape polymorphism for FFT, and tests
PiperOrigin-RevId: 521749241
2023-04-04 06:45:57 -07:00
Sharad Vikram
5101184ad4 Add initial implementation of a run_state primitive 2023-04-03 21:32:32 -07:00
jax authors
8a6c929678 Merge pull request #15289 from cgarciae:add-missing-api-references
PiperOrigin-RevId: 521617419
2023-04-03 18:23:04 -07:00
jax authors
b361f4cd0c Merge pull request #15169 from cgarciae:fix-lstm
PiperOrigin-RevId: 521616002
2023-04-03 18:13:19 -07:00
Yash Katariya
14b572f60d Remove _compile_replicated option from compile since it is not needed anymore and some other cosmetic fixes.
PiperOrigin-RevId: 521604489
2023-04-03 17:17:33 -07:00
Stephan Hoyer
4009005f0c Support extrapolation in jnp.interp
Fixes https://github.com/google/jax/issues/14858
2023-04-03 15:31:14 -07:00
Cristian Garcia
aa12e3597b handle seq_lengths in lstm_ref 2023-04-03 22:22:54 +00:00
Parker Schuh
c2b15a1eb8 Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Yash Katariya
78678ee9e1 Rename count_pjit_cache_miss with count_pjit_cpp_cache_miss because it is confusing which cache the first function is taking about as pjit has many caches
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
6f2256ad17 Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
George Necula
05249ec770 [jax2tf] Add more sharding tests with shape polymorphism
PiperOrigin-RevId: 521471546
2023-04-03 08:54:58 -07:00
George Necula
ff313a37a2 [jax2tf] Skip "graph" mode primitive tests on TPUs.
PiperOrigin-RevId: 521468145
2023-04-03 08:39:36 -07:00
jax authors
d743d23859 Convolution functions in TF, like- tf.nn.depthwise_conv2d_v2, tf.nn.conv2d_transpose_v2, tf.nn.conv2d_v2 all follow the same principal when it comes to padding(explained here- https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2). These principal happens to match with that of lax.convolution.conv_general_dilated.
So, this CL safely uses the padding(list[tuple(int, int)]) to call tf.nn.<conv> functions

PiperOrigin-RevId: 521464565
2023-04-03 08:22:02 -07:00
George Necula
607c7c1fdd Plumbing for dynamic shapes for custom calls.
PiperOrigin-RevId: 521439418
2023-04-03 06:16:12 -07:00
jax authors
0d32724882 Merge pull request #15340 from gnecula:dim_vars3
PiperOrigin-RevId: 521424534
2023-04-03 04:44:16 -07:00
George Necula
cd35e901aa [shape_poly] Cleanup handling of dimension variables.
We unify the way we compute with dimension variables (computing
their values from the shape of the actual arguments, and also
using those values to evaluate shapes that contain dimension variables).

We remove DimExprValueMlir, and all computations with dimension variables
and DimExpr are now done by JAX interpretation, followed by lowering to
TF or StableHLO.
2023-04-03 13:33:29 +02:00
George Necula
bf2c07121b [jax2tf] Add test that compile_args[tuple_args] does not matter for serialization
PiperOrigin-RevId: 521422653
2023-04-03 04:32:34 -07:00
George Necula
2ce78ac9a8 [jax2tf] Add checks that we do not see unexpected lowered.compiler_args
Some of those compile_args change the semantics and the calling convention
for the lowered module. We want to be explicit about the ones that we
are handling.

PiperOrigin-RevId: 521419681
2023-04-03 04:13:31 -07:00
jax authors
b0a6cdbf24 Merge pull request #15341 from gnecula:tf_grad
PiperOrigin-RevId: 521409693
2023-04-03 03:15:36 -07:00