12952 Commits

Author SHA1 Message Date
Yash Katariya
b7e4e44cbf DCE jaxpr and trivial_jaxpr support for lower_sharding_computation
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 471274989
2022-09-06 14:09:10 -07:00
jax authors
e9204e312a Merge pull request #12215 from sharadmv:for-loop-remat
PiperOrigin-RevId: 472536171
2022-09-06 13:04:03 -07:00
jax authors
5506d3fc85 Merge pull request #12223 from gnecula:ds_lower_aval
PiperOrigin-RevId: 472530127
2022-09-06 12:44:51 -07:00
jax authors
f3cefe10e5 Merge pull request #12172 from sudhakarsingh27:separate_marker_for_multinode_gpu_tests
PiperOrigin-RevId: 472530121
2022-09-06 12:38:31 -07:00
Sudhakar
5f1858f533 Add pytest marker inside the test only if pytest is present in the env 2022-09-06 11:45:59 -07:00
Jon Barron
dc4591dd6c Fix NaNs in the gradient of jnp.interp when the spline being interpolated into contains knots that are small and nearby.
PiperOrigin-RevId: 472511203
2022-09-06 11:22:47 -07:00
Sharad Vikram
b2a5d2c3bb Add partial_eval_custom rule for for_loop 2022-09-06 11:00:26 -07:00
Ilia Sergachev
e7ddb2f5a2 Use reduce_window for faster cumulative reductions on small inputs on GPU.
On GPU, reduce_window is executed in a single fusion and associative_scan
is split into multiple to materialize intermediate calculations.
On small inputs reduce_window is faster being a single fusion,
but on larger ones is slower because of O(n^2) complexity.
The conservative value of the threshold to choose between the two algorithms was obtained by benchmarking.
cumred_tpu_impl was renamed into cumred_reduce_window_impl to reflect that it is useful not only on TPUs.

PiperOrigin-RevId: 472504553
2022-09-06 10:58:00 -07:00
jax authors
bfab12a455 Merge pull request #12232 from LenaMartens:not-checkify
PiperOrigin-RevId: 472501061
2022-09-06 10:45:13 -07:00
reinerp
748ba40574 Fix some cases where deserialization hangs:
* Make sure to release the semaphore in the dtype!=None case.
* If chunks are too large to _ever_ acquire the semaphore, immediately raise an
  error.

PiperOrigin-RevId: 472496109
2022-09-06 10:27:57 -07:00
jax authors
ba15118bb4 Merge pull request #12225 from PhilipVinc:patch-1
PiperOrigin-RevId: 472462967
2022-09-06 08:24:57 -07:00
jax authors
048f28055f Merge pull request #12230 from gnecula:tf_refactor_new
PiperOrigin-RevId: 472451131
2022-09-06 07:19:53 -07:00
lenamartens
82f74d1898 Checkify: Remove some err <-not-> pred flipping.
Now all internal APIs deal in terms of errs (error thrown if True), and
only the check API takes a pred (error thrown if False).
2022-09-06 14:59:59 +01:00
George Necula
56c2c0baa8 [jax2tf] Refactor top-level jax2tf.convert
There are several goals for this refactoring:
  * improve the readability of the code: more helper functions, move big
    nested functions to top-level to make make it obvious what are the
    data dependencies
  * try to be more systematic about naming: JAX entities end with _jax
    and TF entities with _tf. This is helpful because in several cases
    one function has to operate with both kinds of entities.
  * the main goal is to enable fixing the experimental_native_lowering
    for pjit. For that (future) work, we want to pass JAX callables
    to _interpret_fun_jax, rather than linear_util.WrappedFun. Then
    we can use the standard AOT APIs.

This was initially reviewed and submitted as #12205,
but was rolled back due to test failures.
2022-09-06 16:58:11 +03:00
Marc van Zee
a1e4f68a73 [jax2tf] Rewrites converters_eval framework.
Resulting table: ebdfb60ee8/jax/experimental/jax2tf/g3doc/convert_models_results.md

* Makes testing models more similar to testing primitives: Moves the framework into `jax2tf/tests`, the main file to call now is `model_test.py` and the models themselves are now in `model_harness.py`.

* Moves the g3doc to `jax2tf/g3doc`

* Simplifies conversion and testing logic.

* Adds more converters and improves the output in the g3doc.

* Fixes various bugs in the conversion. The errors shown now are all problems with the actual converters.

PiperOrigin-RevId: 472437502
2022-09-06 06:04:41 -07:00
Jon Barron
892dea5d87 Fix NaNs in the gradient of jnp.interp when the spline being interpolated into contains the same knot coordinate twice.
PiperOrigin-RevId: 472333594
2022-09-05 16:16:55 -07:00
jax authors
70c339ea2c Internal change
PiperOrigin-RevId: 472301936
2022-09-05 10:49:24 -07:00
jax authors
98429b6181 Merge pull request #12205 from gnecula:tf_refactor
PiperOrigin-RevId: 472293030
2022-09-05 09:39:19 -07:00
George Necula
33d1c08a31 Update jax/experimental/jax2tf/jax2tf.py
Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
2022-09-05 16:33:34 +03:00
George Necula
9e2fa4d24d [jax2tf] Refactor top-level jax2tf.convert
There are several goals for this refactoring:
  * improve the readability of the code: more helper functions, move big
    nested functions to top-level make make it obvious what are the
    data dependencies
  * try to be more systematic about naming: JAX entities end with _jax
    and TF entities with _tf. This is helpful because in several cases
    one function has to operate with both kinds of entities.
  * the main goal is to enable fixing the experimental_native_lowering
    for pjit. For that (future) work, we want to pass JAX callables
    to _interpret_fun_jax, rather than linear_util.WrappedFun. Then
    we can use the standard AOT APIs.
2022-09-05 16:33:34 +03:00
Filippo Vicentini
52236adeed
Show link to GitHub repo in navbar 2022-09-05 13:41:49 +02:00
jax authors
4ae5cb31b0 Merge pull request #12196 from marcvanzee:tfjs
PiperOrigin-RevId: 472206922
2022-09-04 23:41:17 -07:00
George Necula
fe055d06ba Allow get_aval to work on ShapeDtypeStruct
This is necessary to be able to call jit(f).lower(ShapeDtypeStruct(...) when
--jax_dynamic_shapes is on. The code in partial_eval.infer_lambda_input_type
calls get_aval.
2022-09-04 12:11:05 +03:00
jax authors
15fbf22715 Merge pull request #12204 from froystig:aot-docs
PiperOrigin-RevId: 471897501
2022-09-02 15:23:17 -07:00
Kuangyuan Chen
d17e516ea7 Add benchmarks for jax.Array
PiperOrigin-RevId: 471889808
2022-09-02 14:45:09 -07:00
Roy Frostig
a2ad414e7c mention AOT readiness in changelog 2022-09-02 13:02:25 -07:00
Roy Frostig
bb68fbeefa write in-process AOT walkthrough doc 2022-09-02 13:02:25 -07:00
Victor Stone
bd425e5dc5 Rolling back due to breaking tests
PiperOrigin-RevId: 471847099
2022-09-02 11:29:14 -07:00
Marc van Zee
7cc7da8576 Improves quickdraw example in jax2tf 2022-09-02 10:54:27 +02:00
Yash Katariya
3e54ac0af0 Make __iter__ of Array behave like DA when there is a SingleDeviceSharding and like SDA when there is a non-trivial sharding.
This is important because when `Array` contains more than 1 shard, each shard can be on a different device and those things need to be preserved when iterating over `Array`.

PiperOrigin-RevId: 471695841
2022-09-01 19:54:34 -07:00
Roy Frostig
43db06491c write and generate package API documentation for jax.stages 2022-09-01 19:26:53 -07:00
Roy Frostig
4505d57a60 docstring for jax.stages.Wrapped 2022-09-01 18:31:38 -07:00
jax authors
34ce471a44 Merge pull request #12203 from sharadmv:pure-callback-jit
PiperOrigin-RevId: 471678388
2022-09-01 17:49:19 -07:00
jax authors
76a4494027 Merge pull request #12201 from jakevdp:bcoo-dynamic-slice
PiperOrigin-RevId: 471665862
2022-09-01 16:39:17 -07:00
Victor Stone
de876a7ed9 Enable cuBLASLt by default in XLA for most matmuls
Enable cuBLASLt by default in XLA with two exceptions. First, the current XLA implementation using cublasLt does not yet support int8 gemms. Second, the cublasLt api does not support a certain dimension size larger than a specific value; in this case we fallback to legacy cublas. This change makes a modification so that we prefer to do the cublaslt gemm operation in place when fusing with a bias add. Updated JAX test precision for new matmul results.

PiperOrigin-RevId: 471661566
2022-09-01 16:18:26 -07:00
Sharad Vikram
e1410bd16b Use lowering as impl rule for pure_callback 2022-09-01 15:29:31 -07:00
Jake VanderPlas
47b9f216bc [sparse] add sparse support for dynamic_slice 2022-09-01 13:42:02 -07:00
jax authors
0869183107 Merge pull request #12200 from froystig:key-scalar-only-ops
PiperOrigin-RevId: 471620603
2022-09-01 13:41:25 -07:00
Roy Frostig
5ac8043860 require scalar key arrays when seeding, splitting, and folding
Recent changes to RNG internals actually make it easier for us to
render these operations batch-polymorphic. However, any existing use
of these in a non-scalar way suggests incorrect usage, since they were
scalar-only before (albeit imperfectly guarded as such).
2022-09-01 13:02:47 -07:00
jax authors
1dbfa8893c Merge pull request #12178 from gnecula:tf_repeat
PiperOrigin-RevId: 471526861
2022-09-01 07:36:22 -07:00
George Necula
906a212b80 [shape_poly] Add limited support for jnp.repeat with shape polymorphism
Supports only the case when the `repeats` parameter is a dimension polynomials
and the total_repeat_length is None.
2022-09-01 01:09:52 -07:00
Yash Katariya
0584c6a1c4 Add support to handle arbitrary shardings to KeyArray. Resolve all the TODOs that were created before.
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 471443690
2022-08-31 22:54:06 -07:00
jax authors
bf7525e121 Merge pull request #12170 from froystig:just-dtype
PiperOrigin-RevId: 471409020
2022-08-31 18:36:47 -07:00
Yash Katariya
6eb80fb0e6 Add fast path args to Array similar to GDA to speed up initialization and other operations like calculating indices and addressable_device_assignment.
This is important because looping over 1000s of devices is extremely expensive during runtime and throttles the performance (all these optimizations were applied to GDA when integrating it into PAX and are applicable to Array as well). This will also be helpful for single-controller environments.

Also even hashing and __eq__ checks when you have 1000s of devices is going to be slow and will show up in xprof as a slowdown (I have seen this before).

PiperOrigin-RevId: 471366295
2022-08-31 15:11:44 -07:00
John QiangZhang
59c2fc9ea9 [jax2tf] add a new test for jax2tf gda test.
Now it cover the test using gda as jax function input.

PiperOrigin-RevId: 471365834
2022-08-31 15:05:49 -07:00
Yash Katariya
2f7951b3dc Add __hash__ and __eq__ to PmapSharding
PiperOrigin-RevId: 471356052
2022-08-31 14:27:16 -07:00
jax authors
e1b250caf4 Merge pull request #12186 from sharadmv:fix-64-dtype
PiperOrigin-RevId: 471350419
2022-08-31 14:08:15 -07:00
jax authors
fc7c8de3f9 Merge pull request #12184 from froystig:key-array-pickle
PiperOrigin-RevId: 471327279
2022-08-31 12:52:47 -07:00
Sharad Vikram
311a9cb5d9 Throw error when 64-bit dtypes used incorrectly in jax.pure_callback 2022-08-31 12:31:04 -07:00
jax authors
e98eb442fa Merge pull request #12177 from froystig:test-rng-upgrade
PiperOrigin-RevId: 471317729
2022-08-31 12:17:16 -07:00