This is an implementation of a batch-friendly multidimensional COO format for JAX. It contains implementations of four primitives (bcoo_todense, bcoo_fromdense, bcoo_extract, bcoo_dot_general), as well as batching, JVP, and transpose rules for each.
For convenience, this also adds a BCOO class, which is a pytree wrapper around these primitives.
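A minimal usage sketch (assuming the public wrappers live under `jax.experimental.sparse`; the module path has moved across JAX versions):

```python
import jax.numpy as jnp
from jax.experimental import sparse  # module path is an assumption

# Dense 3-D array; n_batch=1 keeps the leading axis as a dense batch axis.
M = jnp.array([[[0., 1.], [2., 0.]],
               [[3., 0.], [0., 4.]]])
M_sp = sparse.BCOO.fromdense(M, n_batch=1)   # wraps bcoo_fromdense

v = jnp.array([1., 2.])
# bcoo_dot_general contracts the sparse operand against a dense one.
out = sparse.bcoo_dot_general(
    M_sp, v, dimension_numbers=(((2,), (0,)), ((), ())))

assert jnp.allclose(out, M @ v)
assert jnp.allclose(M_sp.todense(), M)       # wraps bcoo_todense
```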
It turned out that, in jax2tf._dynamic_slice, tf.constant doesn't work with polymorphic shapes, where dimension sizes may be tf.Tensor values rather than Python ints, so I replaced it with tf.cast.
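A hedged sketch of the pattern (illustrative, not the actual jax2tf code):

```python
import tensorflow as tf

def clamp_start_index(start, dim_size, slice_size):
  """Clamp a dynamic-slice start index into [0, dim_size - slice_size]."""
  # With shape polymorphism `dim_size` may be a scalar tf.Tensor, which
  # tf.constant() would reject; tf.cast() accepts both ints and tensors.
  limit = tf.cast(dim_size - slice_size, start.dtype)
  zero = tf.zeros([], dtype=start.dtype)
  return tf.minimum(tf.maximum(start, zero), limit)
```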
PiperOrigin-RevId: 378392273
The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.
In order to ensure that this metadata is carried from the JAX tracing
time to TF/XLA, we save the metadata in custom TF op attributes. These
attributes are automatically preserved through SavedModel. This relies
on a separate change in TF/XLA to look for these custom attributes
and use them to override the metadata it would otherwise generate by default.
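Concretely, the metadata we need to carry mirrors the fields of XLA's HLO OpMetadata; a minimal sketch (the TF attribute names used to store it are an implementation detail and not shown):

```python
from dataclasses import dataclass

@dataclass
class OpMetadata:
  """The HLO OpMetadata fields we want to preserve through jax2tf."""
  op_type: str      # e.g. "le": the JAX primitive name
  op_name: str      # hierarchical name, e.g. "jax2tf(top_func)/while/cond/le"
  source_file: str  # top-most user frame, after filtering library frames
  source_line: int
```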
For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.
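A hedged sketch of that filtering (the path tests are illustrative, not the exact ones used):

```python
import os
import traceback

# Frames whose filenames live under these directories are treated as
# library frames and skipped when picking the user source location.
_EXCLUDED_PATH_FRAGMENTS = [
    f"{os.sep}jax{os.sep}",         # JAX source tree
    f"{os.sep}tensorflow{os.sep}",  # TensorFlow source tree
]

def user_source_location():
  """Return (filename, lineno) of the top-most non-JAX, non-TF frame."""
  for frame in reversed(traceback.extract_stack()):
    if not any(p in frame.filename for p in _EXCLUDED_PATH_FRAGMENTS):
      return frame.filename, frame.lineno
  return "<unknown>", 0
```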
The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that achieves this in xla.py. At the
same time, it is not crucial that we produce name stacks identical to JAX's.
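A minimal sketch of such a thread-local name stack (names and structure are illustrative):

```python
import contextlib
import threading

_thread_state = threading.local()

def _name_stack():
  if not hasattr(_thread_state, "name_stack"):
    _thread_state.name_stack = []
  return _thread_state.name_stack

@contextlib.contextmanager
def extend_name_stack(component: str):
  """Push one '/'-separated component, e.g. 'while', 'cond', 'jax2tf(top_func)'."""
  stack = _name_stack()
  stack.append(component)
  try:
    yield
  finally:
    stack.pop()

def current_op_name(primitive_name: str) -> str:
  # e.g. "jax2tf(top_func)/while/cond/le"
  return "/".join(_name_stack() + [primitive_name])
```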
I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted function we
have to reuse the current `MainTrace` even though we want to push an element
onto the name stack.
For now this option is disabled until we make the necessary
changes in TensorFlow.
complex128 isn't supported on TPUs in TF; since tf.constant now places values on the TPU by default, _is_tfval saw the resulting exception and assumed the value wasn't convertible to a TF type.
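A hedged sketch of the check and the fix (pinning the constant to the CPU is my reading of the fix, not a copy of the actual code):

```python
import tensorflow as tf

def _is_tfval(v) -> bool:
  """Best-effort check that `v` can be converted to a TF value."""
  if isinstance(v, (tf.Tensor, tf.Variable)):
    return True
  try:
    # Pin the constant to the CPU; otherwise tf.constant may try to place a
    # complex128 value directly on a TPU, which raises and would make us
    # wrongly report the value as not convertible.
    with tf.device("CPU"):
      tf.constant(v)
    return True
  except (TypeError, ValueError, tf.errors.OpError):
    return False
```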
PiperOrigin-RevId: 378240447
This is especially useful because it makes it easy to implement
"memory-limited vmaps". It might also come in handy for pipelining,
as that could represent the microbatching loop.
Note that at the moment the xmap has to be a pure map along all axes
that are assigned to loop resources. No collectives are supported.
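A hedged sketch of a memory-limited map, assuming the loop-resource API looks roughly like the `serial_loop` context manager (exact names may differ):

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental.maps import xmap, serial_loop  # names are an assumption

x = jnp.arange(16.0)

# Run a pure map in 4 sequential chunks instead of one big vmap,
# bounding peak memory ("memory-limited vmap").
with serial_loop("l", 4):
  f = xmap(lambda v: jnp.sin(v) ** 2,
           in_axes=["i", ...], out_axes=["i", ...],
           axis_resources={"i": "l"})
  y = f(x)

np.testing.assert_allclose(y, np.sin(np.arange(16.0)) ** 2, rtol=1e-6)
```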