There have been some recent breakages affecting the nightly driver,
causing JAX operations to fail on Cloud TPU Colabs. Pinning to a
specific version will alleviate these problems. This version may need
to be updated if there are breaking changes to the tpu_driver
client/server boundary, but that doesn't happen very often.
Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
XLA itself does not consume these hints, but they can be propagated onto scatter() when computing gradients.
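For context, a minimal sketch of the user-facing side of these hints, assuming the `indices_are_sorted`/`unique_indices` keywords on the `.at` update methods (the automatic computation described above happens during lowering, not in user code):
```python
import jax.numpy as jnp

x = jnp.zeros(8)
idx = jnp.array([1, 3, 6])   # the caller knows these are sorted and unique

# Indexed update (a scatter-add). The hints let the lowered scatter skip the
# extra handling needed for duplicate or unsorted indices.
y = x.at[idx].add(1.0, indices_are_sorted=True, unique_indices=True)
print(y)   # [0. 1. 0. 1. 0. 0. 1. 0.]
```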
This is an implementation of a batch-friendly multidimensional COO format for JAX. It contains implementations of four primitives (bcoo_todense, bcoo_fromdense, bcoo_extract, bcoo_dot_general), as well as batching, JVP, and transpose rules for each.
For convenience, this also adds a `BCOO` class, which is a pytree-compatible wrapper around these primitives.
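A minimal sketch of the wrapper, assuming the class is exposed as `jax.experimental.sparse.BCOO` (the exact module path may differ from the one in this change):
```python
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0., 1., 0.],
               [2., 0., 3.]])
M_sp = sparse.BCOO.fromdense(M)   # wraps bcoo_fromdense; stores .data and .indices
v = jnp.array([1., 2., 3.])
print(M_sp @ v)                   # dispatches to bcoo_dot_general -> [ 2. 11.]
print(M_sp.todense())             # wraps bcoo_todense; recovers the dense matrix
```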
It turned out that, in jax2tf._dynamic_slice, tf.constant doesn't work with polymorphic shapes, so I replaced it with a tf.cast.
PiperOrigin-RevId: 378392273
The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.
To ensure that this metadata is carried from JAX tracing time to TF/XLA, we save the metadata in custom TF op attributes. These attributes are automatically preserved through SavedModel. This relies on a separate change in TF/XLA to look for these custom attributes and override its defaults.
For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.
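As an illustrative sketch (hypothetical helper, not the actual jax2tf code), the frame selection amounts to walking the stack and keeping the innermost frame that lies outside both source trees:
```python
import os
import traceback

import jax
import tensorflow as tf

_JAX_DIR = os.path.dirname(jax.__file__)
_TF_DIR = os.path.dirname(tf.__file__)

def top_user_frame():
  # Scan from the innermost frame outward; the first frame that is not in the
  # JAX or TensorFlow source trees is the single location recorded in OpMetadata.
  for frame in reversed(traceback.extract_stack()):
    if not frame.filename.startswith((_JAX_DIR, _TF_DIR)):
      return f"{frame.filename}:{frame.lineno}"
  return None
```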
The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that xla.py uses for this. At the same
time, it is not crucial that our name stacks be exactly identical to JAX's.
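A minimal, hypothetical sketch of such a thread-local name stack (`name_scope` and `current_op_name` are illustrative names, not the jax2tf internals):
```python
import contextlib
import threading

_state = threading.local()

def _stack():
  if not hasattr(_state, "name_stack"):
    _state.name_stack = []
  return _state.name_stack

@contextlib.contextmanager
def name_scope(component: str):
  # Entering a scope pushes one hierarchical component; leaving pops it.
  _stack().append(component)
  try:
    yield
  finally:
    _stack().pop()

def current_op_name(op: str) -> str:
  return "/".join(_stack() + [op])

with name_scope("jax2tf(top_func)"):
  with name_scope("while"), name_scope("cond"):
    print(current_op_name("le"))   # jax2tf(top_func)/while/cond/le
```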
I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted function we
have to reuse the current `MainTrace` even though we want to push an element
onto the name stack.
For now this option is disabled; it will be enabled once we make the necessary
changes in TensorFlow.
complex128 isn't supported on TPUs in TF, and tf.constant now places constants on the TPU by default; _is_tfval saw the resulting exception and assumed the value wasn't convertible to a TF type.
PiperOrigin-RevId: 378240447
This is especially useful because it makes it easy to implement
"memory-limited vmaps". It might also come in handy for pipelining,
since a loop resource could represent the microbatching loop.
Note that at the moment the xmap has to be a pure map along all axes
that are assigned to loop resources. No collectives are supported.
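A minimal sketch of such a loop-mapped xmap, assuming the `serial_loop` helper in `jax.experimental.maps` from this era of the API:
```python
import jax.numpy as jnp
from jax.experimental.maps import xmap, serial_loop

def f(x):
  return jnp.sin(x) ** 2        # a pure map: no collectives over 'batch'

# Assign the 'batch' axis to a serial loop resource so it is evaluated in
# sequential chunks rather than all at once ("memory-limited vmap").
with serial_loop('l', 4):
  g = xmap(f,
           in_axes=['batch', ...],
           out_axes=['batch', ...],
           axis_resources={'batch': 'l'})
  y = g(jnp.arange(64.0).reshape(16, 4))   # batch of 16, looped over 4 steps
```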
This should let us emit good XLA annotations for `xmap(pjit)`. Previously
we might have been overestimating the set of replicated mesh dimensions.
PiperOrigin-RevId: 377259226
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multivariate dimension polynomials. These polynomials overload addition,
subtraction, and multiplication, and partially support equality and inequality checking.
Equality and inequality are supported only when the result of the operation is the
same for all valuations of the variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, and `a >= a` all succeed,
whereas `a == b` and `a >= 2` raise a `core.InconclusiveDimensionOperation`.
Division is supported only when either there is no remainder or the divisor is
a constant.
This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
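As an illustration, a minimal sketch of this pattern going through `jax2tf.convert` with the `polymorphic_shapes` argument (function and shapes here are hypothetical, not from this change):
```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):                      # x : f32[b, 4]
  y = x.reshape((2, -1))       # -1 is resolved to the dimension polynomial 2*b
  return (y * 2.0).reshape(x.shape)

f_tf = tf.function(jax2tf.convert(f, polymorphic_shapes=["(b, 4)"]),
                   autograph=False,
                   input_signature=[tf.TensorSpec([None, 4], tf.float32)])

# One traced graph handles any batch size b.
print(f_tf(np.ones((3, 4), np.float32)).shape)   # (3, 4)
print(f_tf(np.ones((5, 4), np.float32)).shape)   # (5, 4)
```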
JAX and TensorFlow have different behavior w.r.t. 32-bit vs. 64-bit
computations. This PR cleans up the handling of types in jax2tf
to ensure that jax2tf follows the same behavior as JAX.
This means that f_jax(args) always does the computation with the
same precision as jax2tf.convert(f_jax)(args). This may mean that
the result of the conversion depends on the value of JAX_ENABLE_X64.
See README.md for more details.
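For example, a minimal sketch (the dtype outcomes shown assume the default 32-bit mode; with JAX_ENABLE_X64 set, both results would be float64):
```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def f_jax(x):
  return jnp.sin(x) + 3.0

x = np.float64(1.0)
print(f_jax(x).dtype)                  # float32 under the default 32-bit mode
print(jax2tf.convert(f_jax)(x).dtype)  # same dtype: the conversion follows JAX's precision
```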