On GPU, reduce_window is executed as a single fusion, while associative_scan
is split into multiple fusions that materialize intermediate results.
On small inputs reduce_window is faster because it is a single fusion,
but on larger inputs it is slower because it does O(n^2) work.
A conservative threshold for choosing between the two algorithms was obtained by benchmarking.
cumred_tpu_impl was renamed to cumred_reduce_window_impl to reflect that it is useful not only on TPUs.
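For illustration, a minimal sketch of this dispatch for a 1-D float input; `_CUTOFF` and `cumsum_sketch` are made-up names, and the real threshold is the benchmarked value, not the constant shown here:

```python
# Sketch only, not JAX's actual implementation.
import operator

from jax import lax

_CUTOFF = 1024  # hypothetical; the real cutoff was chosen by benchmarking

def cumsum_sketch(x):  # x: rank-1 float array
  n = x.shape[0]
  if n <= _CUTOFF:
    # reduce_window: compiles to a single fusion, but does O(n^2) work,
    # since output i reduces a window covering the whole prefix [0, i].
    return lax.reduce_window(
        x, 0.0, lax.add,
        window_dimensions=(n,), window_strides=(1,),
        padding=[(n - 1, 0)])
  # associative_scan: O(n log n) work, split into multiple fusions that
  # materialize intermediate results.
  return lax.associative_scan(operator.add, x)
```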
PiperOrigin-RevId: 472504553
* Make sure to release the semaphore in the dtype!=None case.
* If chunks are too large to _ever_ acquire the semaphore, immediately raise an
error.
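A minimal sketch of both fixes, with hypothetical names (a byte-budget limiter whose acquire fails fast for oversized chunks, and a read path that releases the budget on every branch, including the dtype != None one):

```python
import asyncio

class ByteLimiter:
  """Hypothetical limiter on in-flight bytes, acting as a semaphore."""
  def __init__(self, limit_bytes: int):
    self._limit = limit_bytes
    self._available = limit_bytes
    self._cv = asyncio.Condition()

  async def acquire(self, nbytes: int):
    if nbytes > self._limit:
      # A chunk this large could never fit under the budget: raise
      # immediately rather than blocking on the semaphore forever.
      raise ValueError(
          f"Requested {nbytes} bytes, larger than the limit {self._limit}")
    async with self._cv:
      await self._cv.wait_for(lambda: self._available >= nbytes)
      self._available -= nbytes

  async def release(self, nbytes: int):
    async with self._cv:
      self._available += nbytes
      self._cv.notify_all()

async def read_chunk(limiter, nbytes, do_read, dtype=None):
  await limiter.acquire(nbytes)
  try:
    chunk = await do_read()
    if dtype is not None:
      chunk = chunk.astype(dtype)  # this branch must not skip the release
    return chunk
  finally:
    await limiter.release(nbytes)  # released on every code path
```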
PiperOrigin-RevId: 472496109
There are several goals for this refactoring:
* improve the readability of the code: more helper functions, move big
nested functions to top-level to make it obvious what the data
dependencies are
* try to be more systematic about naming: JAX entities end with _jax
and TF entities with _tf. This is helpful because in several cases
one function has to operate on both kinds of entities.
* the main goal is to enable fixing the experimental_native_lowering
for pjit. For that (future) work, we want to pass JAX callables
to _interpret_fun_jax, rather than linear_util.WrappedFun. Then
we can use the standard AOT APIs.
This was initially reviewed and submitted as #12205,
but was rolled back due to test failures.
Resulting table: ebdfb60ee8/jax/experimental/jax2tf/g3doc/convert_models_results.md
* Makes testing models more similar to testing primitives: moves the framework into `jax2tf/tests`; the main entry point is now `model_test.py`, and the models themselves are now in `model_harness.py`.
* Moves the g3doc to `jax2tf/g3doc`
* Simplifies conversion and testing logic.
* Adds more converters and improves the output in the g3doc.
* Fixes various bugs in the conversion. The errors shown now are all problems with the actual converters.
PiperOrigin-RevId: 472437502
This is necessary to be able to call jit(f).lower(ShapeDtypeStruct(...)) when
--jax_dynamic_shapes is on. The code in partial_eval.infer_lambda_input_type
calls get_aval.
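A sketch of the target call through the standard AOT API, shown with a concrete shape (under --jax_dynamic_shapes the sizes could instead be symbolic); `f` is just an example function:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

# Lower from an abstract value instead of a concrete array:
lowered = jax.jit(f).lower(jax.ShapeDtypeStruct((8,), jnp.float32))
compiled = lowered.compile()
print(compiled(jnp.arange(8.0, dtype=jnp.float32)))
```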
This is important because when `Array` contains more than one shard, each shard can be on a different device, and that device placement needs to be preserved when iterating over `Array`.
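For illustration (using today's `jax.sharding` API, which may differ from the experimental `Array` code this change touched), a sharded array whose shards live on different devices, i.e. the placement information that iteration must preserve:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Shard a 1-D array across all available devices along axis "x".
mesh = Mesh(jax.devices(), ("x",))
arr = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, PartitionSpec("x")))
for shard in arr.addressable_shards:
  print(shard.device, shard.data.shape)  # one entry per device
```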
PiperOrigin-RevId: 471695841
Enable cuBLASLt by default in XLA, with two exceptions. First, the current XLA implementation using cublasLt does not yet support int8 gemms. Second, the cublasLt API does not support dimension sizes above a certain limit; in that case we fall back to legacy cuBLAS. This change also makes us prefer to do the cublasLt gemm operation in place when fusing with a bias add. Updated JAX test precision for the new matmul results.
PiperOrigin-RevId: 471661566
Recent changes to RNG internals actually make it easier for us to
render these operations batch-polymorphic. However, any existing use
of these in a non-scalar way suggests incorrect usage, since they were
scalar-only before (albeit imperfectly guarded as such).
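Illustration only: "batch-polymorphic" here means the operation also works on a batch of keys, e.g. under vmap. Which internal RNG operations the change covers is not spelled out above; key-splitting is used as a stand-in example:

```python
import jax
import jax.numpy as jnp

# A batch of 4 keys, then a batched split: previously such operations were
# (imperfectly) restricted to a single scalar key.
keys = jnp.stack([jax.random.PRNGKey(i) for i in range(4)])
subkeys = jax.vmap(jax.random.split)(keys)
print(subkeys.shape)  # (4, 2, 2)
```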
This is important because looping over 1000s of devices is extremely expensive at runtime and throttles performance (all these optimizations were applied to GDA when integrating it into PAX, and they are applicable to `Array` as well). This will also be helpful in single-controller environments.
Even hashing and __eq__ checks are slow when you have 1000s of devices, and they will show up in xprof as a slowdown (I have seen this before).
PiperOrigin-RevId: 471366295