Peter Hawkins
dfe95dcb4e
Split ShardingSpecs and most of the helpers for constructing them into a separate file (jax/_src/sharding_specs.py).
...
PiperOrigin-RevId: 522360232
2023-04-06 09:48:51 -07:00
Yash Katariya
b8ade584bf
Add more multi device array slicing tests
...
PiperOrigin-RevId: 522345812
2023-04-06 08:45:36 -07:00
Peter Hawkins
452f3c55e3
Rename jax._src.sharding_utils to jax._src.op_shardings.
...
Move some more op_sharding related helpers to that module.
PiperOrigin-RevId: 522343010
2023-04-06 08:32:46 -07:00
jax authors
492b9c1455
Merge pull request #15397 from jakevdp:fix-split-annotation
...
PiperOrigin-RevId: 522341314
2023-04-06 08:23:54 -07:00
Yash Katariya
b926e04afc
Remove the shim of functions in sharding_utils from pxla.py and use those functions directly from sharding_utils in JAX
...
PiperOrigin-RevId: 522319332
2023-04-06 06:18:03 -07:00
jax authors
95525e7f8d
Merge pull request #15399 from mattjj:issue15398
...
PiperOrigin-RevId: 522255300
2023-04-05 23:23:27 -07:00
Yash Katariya
728a5ed96a
[shard-map] fix eager shmap+prngs, revise phys aval/sharding logic
...
Co-authored-by: Yash Katariya <yashkatariya@google.com>
2023-04-05 23:04:41 -07:00
jax authors
8ef3b04deb
Merge pull request #15428 from mattjj:jax-array
...
PiperOrigin-RevId: 522232242
2023-04-05 21:00:31 -07:00
Matthew Johnson
f9259f3f62
fix a __jax_array__
bug
2023-04-05 20:24:53 -07:00
Yash Katariya
03d5aaad96
Switch the implementation of sharded_aval to a simpler one.
...
Create sharding_utils.py to move utilities from pxla.py to sharding_utils.py to break cyclic deps.
PiperOrigin-RevId: 522209346
2023-04-05 18:32:00 -07:00
Jake VanderPlas
c10cb17751
Accelerate deprecation of jax.ShapedArray
...
This is deprecated as of https://github.com/google/jax/pull/15263 : most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion.
PiperOrigin-RevId: 522168275
2023-04-05 15:26:19 -07:00
Jake VanderPlas
8c8f50f688
Fix tolerance and shard_count for experimental_rnn_test
...
This should fix the current GPU test timeout.
PiperOrigin-RevId: 522167894
2023-04-05 15:19:19 -07:00
jax authors
f6da71c807
Merge pull request #15401 from mattjj:issue15400
...
PiperOrigin-RevId: 522115286
2023-04-05 11:55:47 -07:00
Peter Hawkins
29ba2ca926
Report the argument path when encountering an overflow error for a Python value.
...
PiperOrigin-RevId: 522106244
2023-04-05 11:24:40 -07:00
jax authors
7b4e579202
Merge pull request #15415 from jakevdp:sparse-add
...
PiperOrigin-RevId: 522101591
2023-04-05 11:07:50 -07:00
Jake VanderPlas
05f32a7947
[sparse] allow sparse-dense add when the output is the same size as dense input
2023-04-05 10:39:43 -07:00
John QiangZhang
0e549ac4be
Update unit test and doc how work around jit_compile=False on TPU for native_serialization.
...
PiperOrigin-RevId: 522077249
2023-04-05 09:44:34 -07:00
Matthew Johnson
ac4942d7f7
fix conj transpose on symbolic zero
...
fixes #15400
2023-04-04 20:45:21 -07:00
Peter Hawkins
bf50551e0f
Explicitly import jax.custom_{batching,derivatives,transpose}.
...
https://github.com/google/jax/pull/15391 had the unintentional side effect of causing these names not to be imported by default. Restore the status quo by importing them.
PiperOrigin-RevId: 521898088
2023-04-04 16:40:15 -07:00
jax authors
aab24fead0
Merge pull request #15396 from jakevdp:conv-elem-type
...
PiperOrigin-RevId: 521895290
2023-04-04 16:26:30 -07:00
Jake VanderPlas
aa643a895c
[typing] fix annotation of jnp.split
2023-04-04 16:17:21 -07:00
Jake VanderPlas
c2fe350455
future-proof lax.convert_element_type
...
In the future, np.array(large_value, 'int32') will error
2023-04-04 15:57:32 -07:00
Yash Katariya
ffa9d018d6
DCE as early as possible so that committed
is not dependent on DCE's vars
...
PiperOrigin-RevId: 521879918
2023-04-04 15:21:12 -07:00
Parker Schuh
9095faaeb0
Remove PyBuffer type and its bindings.
...
PiperOrigin-RevId: 521865179
2023-04-04 14:24:23 -07:00
Parker Schuh
6040580fa3
stages should not eagerly load the executables by calling cpp_call.
...
PiperOrigin-RevId: 521849296
2023-04-04 13:25:24 -07:00
Peter Hawkins
75d0f6522d
Add cupti pip dependency, needed for GPU profiling.
...
Issue https://github.com/google/jax/issues/15384
PiperOrigin-RevId: 521841461
2023-04-04 12:55:36 -07:00
Peter Hawkins
c1f65fc8b2
Avoid imports from the public jax.* namespace in more places internally.
...
This change is in preparation for more cycle breaking in the Bazel dependency graph.
PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
jax authors
3c1f3abba2
Merge pull request #15149 from sharadmv:runstate
...
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
jax authors
efcd85daeb
Merge pull request #15086 from shoyer:extrap
...
PiperOrigin-RevId: 521796315
2023-04-04 10:12:13 -07:00
jax authors
9bb3d86647
Merge pull request #15390 from jakevdp:checkify-dynamic-slice
...
PiperOrigin-RevId: 521790925
2023-04-04 09:54:19 -07:00
Jake VanderPlas
46297dccaf
checkify: catch OOB errors in dynamic_slice
...
This will allow checkify tests to continue working properly after #15377
2023-04-04 08:16:59 -07:00
George Necula
35bfdc65e8
[shape_poly] Add some support for shape polymorphism for FFT, and tests
...
PiperOrigin-RevId: 521749241
2023-04-04 06:45:57 -07:00
Sharad Vikram
5101184ad4
Add initial implementation of a run_state
primitive
2023-04-03 21:32:32 -07:00
jax authors
8a6c929678
Merge pull request #15289 from cgarciae:add-missing-api-references
...
PiperOrigin-RevId: 521617419
2023-04-03 18:23:04 -07:00
jax authors
b361f4cd0c
Merge pull request #15169 from cgarciae:fix-lstm
...
PiperOrigin-RevId: 521616002
2023-04-03 18:13:19 -07:00
Yash Katariya
14b572f60d
Remove _compile_replicated option from compile
since it is not needed anymore and some other cosmetic fixes.
...
PiperOrigin-RevId: 521604489
2023-04-03 17:17:33 -07:00
Stephan Hoyer
4009005f0c
Support extrapolation in jnp.interp
...
Fixes https://github.com/google/jax/issues/14858
2023-04-03 15:31:14 -07:00
Cristian Garcia
aa12e3597b
handle seq_lengths in lstm_ref
2023-04-03 22:22:54 +00:00
Parker Schuh
c2b15a1eb8
Break out aot_test from array_test (for serialization and other aot APIs).
...
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00
Yash Katariya
78678ee9e1
Rename count_pjit_cache_miss
with count_pjit_cpp_cache_miss
because it is confusing which cache the first function is taking about as pjit has many caches
...
PiperOrigin-RevId: 521559652
2023-04-03 14:15:02 -07:00
Yash Katariya
6f2256ad17
Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error
...
PiperOrigin-RevId: 521507810
2023-04-03 11:05:25 -07:00
George Necula
05249ec770
[jax2tf] Add more sharding tests with shape polymorphism
...
PiperOrigin-RevId: 521471546
2023-04-03 08:54:58 -07:00
George Necula
ff313a37a2
[jax2tf] Skip "graph" mode primitive tests on TPUs.
...
PiperOrigin-RevId: 521468145
2023-04-03 08:39:36 -07:00
jax authors
d743d23859
Convolution functions in TF, like- tf.nn.depthwise_conv2d_v2, tf.nn.conv2d_transpose_v2, tf.nn.conv2d_v2 all follow the same principal when it comes to padding(explained here- https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2 ). These principal happens to match with that of lax.convolution.conv_general_dilated
.
...
So, this CL safely uses the padding(list[tuple(int, int)]) to call tf.nn.<conv> functions
PiperOrigin-RevId: 521464565
2023-04-03 08:22:02 -07:00
George Necula
607c7c1fdd
Plumbing for dynamic shapes for custom calls.
...
PiperOrigin-RevId: 521439418
2023-04-03 06:16:12 -07:00
jax authors
0d32724882
Merge pull request #15340 from gnecula:dim_vars3
...
PiperOrigin-RevId: 521424534
2023-04-03 04:44:16 -07:00
George Necula
cd35e901aa
[shape_poly] Cleanup handling of dimension variables.
...
We unify the way we compute with dimension variables (computing
their values from the shape of the actual arguments, and also
using those values to evaluate shapes that contain dimension variables).
We remove DimExprValueMlir, and all computations with dimension variables
and DimExpr are now done by JAX interpretation, followed by lowering to
TF or StableHLO.
2023-04-03 13:33:29 +02:00
George Necula
bf2c07121b
[jax2tf] Add test that compile_args[tuple_args] does not matter for serialization
...
PiperOrigin-RevId: 521422653
2023-04-03 04:32:34 -07:00
George Necula
2ce78ac9a8
[jax2tf] Add checks that we do not see unexpected lowered.compiler_args
...
Some of those compile_args change the semantics and the calling convention
for the lowered module. We want to be explicit about the ones that we
are handling.
PiperOrigin-RevId: 521419681
2023-04-03 04:13:31 -07:00
jax authors
b0a6cdbf24
Merge pull request #15341 from gnecula:tf_grad
...
PiperOrigin-RevId: 521409693
2023-04-03 03:15:36 -07:00