To solve a circular dependency problem where some functions in jax._src.lax.lax depend on slicing, I moved a number of utility functions, e.g., standard_primitive, into a new module `jax._src.lax.utils`. Only utilities that need to be present at module import time were moved.
PiperOrigin-RevId: 411921794
Previously, jax.jit was ignored by jax2tf. This can result in the
converted code being much slower than the JAX core, unless the
user adds an explicit `tf.function(jit_compile=True)`. With this
change that wrapper is added automatically for all code fragments
under jax.jit. Note that most jax.numpy functions are annotated
with jax.jit, so with this change they will all be compiled.
When doing this I ran into problems with tf.custom_gradient and
tf.function. As documented in the
[tf.custom_gradient](https://www.tensorflow.org/api_docs/python/tf/custom_gradient)
documentation, you get a LookupError when trying to build the gradient
of a tf.function, even if it has a tf.custom_gradient defined. The
recommended solution is to add a tf.stop_gradient. This is safe, since
jax2tf will always wrap the converted functions with a tf.custom_gradient.
This is similar to the support in lax.reduce(), where the operands and init_values become pytrees. This is a strict superset of the current API, so users should not need updates.
Variadic lax.reduce_window() is only supported on CPU and TPU at the moment, not GPU.
PiperOrigin-RevId: 411632993
https://github.com/google/jax/pull/8606 introduced a runtime error where as a consequence of the move, a reference to `slice` became a reference to the builtin slice operator instead of `lax.slice`.
After fixing that and while added a test, I noticed that the gradient was wrong before: we should have been slicing the result, not the operand in the transpose rule's handling of base dilation.
Also enable some TPU tests that now pass since we have variadic reduce-window support on TPU.
PiperOrigin-RevId: 411579650
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432
The only reason jax2tf needs access to the internals of jax.lax is when it wants to reuse various translation rule helpers; keep those as explicit internal imports.
This change is partially to minimize churn as jax._src.lax is restructured.
PiperOrigin-RevId: 411276891