The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.
In order to ensure that this metadata is carried from the JAX tracing
time to TF/XLA, we save the metadata in custom TF op attributes. These
attributes are automatically preserved through SavedModel. This relies
on a separate change in TF/XLA to look for these custom attributes
and override its default.
For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.
The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that achieves this in xla.py. At the
same time it is not crucial that we have exactly identical name stacks as in
JAX.
I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted-function we
have to reuse the current `MainTrace` although we want to push an element
on the name stack.
For now this option is not yet enabled until we make the necessary
changes in TensorFlow.
JAX and TensorFlow have different behavior w.r.t. 32-64 bit
computations. This PR cleans up the handling of types in jax2tf
to ensure that we follow the same behavior in jax2tf and in JAX.
This means that f_jax(args) always does the computation with the
same precision as jax2tf.convert(f_jax)(args). This may mean that
the result of the conversion depends on the value of JAX_ENABLE_x64.
See README.md for more details.
Previously, the global enable_xla flag was set upon entry to
`jax.convert`. It should instead be set only for the duration
of the just-in-time conversion, which may happen later when
the converted function is invoked.
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.
Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
Update XLA.
CUDA 11.1 wheels are compatible with CUDA versions 11.1+, since NVidia now promises enhanced version compatibility between CUDA minor releases starting with CUDA 11.1
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
This was originally commited in
b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in
14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test
failures from renaming the local_devices argument name. This change is
identical except it also adds staging for the argument name change.
We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
We make sure that both the inputs and the outputs of
callbacks can contain empty arrays.
Most platforms do not support empty infeed, so we ensure
we do not send those.
* Moved CHANGELOG to docs
This puts the documentation also on RTD, with TOC.
Also changed its format to .rst, for consistency.
Added GitHub links to the change log.
* Actually add the CHANGELOG.rst
* Added reminder comments to the CHANGELOG.rst