We frequently use the pattern
try:
  import m
except ImportError:
  pass  # do something else
This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
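For illustration, a minimal sketch of the narrower pattern; the module name is a placeholder:

try:
  import some_optional_module  # a hypothetical optional dependency
except ModuleNotFoundError:
  some_optional_module = None  # the module is simply absent; fall back
# An ImportError raised from inside the module (module present but buggy)
# is no longer silenced and surfaces to the caller.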
This is much more convenient and lets us register callbacks that trigger on
changes. I want to add more toggles (e.g. for the SPMD lowering that restricts
sharding of every intermediate), so I want to work out a reasonable approach to
do that first.
Second attempt, this time without hardening against the flags being
registered too late due to delayed imports.
PiperOrigin-RevId: 384902895
PiperOrigin-RevId: 384892199
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.
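For illustration, a standalone sketch of the multiplication-based strategy
(exponentiation by squaring); this mirrors the idea, not the actual jax2tf code:

def integer_pow_via_mul(x, y: int):
  # Assumes y >= 1. Builds x**y out of multiplications only, so the
  # rounding behaviour matches a chain of multiplies rather than pow().
  acc = None
  while y > 0:
    if y & 1:
      acc = x if acc is None else acc * x
    y >>= 1
    if y > 0:
      x = x * x
  return acc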
Also improved the error messages for assertion failures.
PiperOrigin-RevId: 373351147
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:
[jax2tf] Change the conversion of dot_general to use XLA op.
Instead of converting dot_general to a sea of TF ops, when enable_xla
is set we just use the XLA op directly. This has the advantage
that it also supports preferred_element_type.
Fixed a bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.
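For reference, a minimal JAX-side dot_general call that exercises
preferred_element_type (shapes and dtypes here are illustrative):

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((2, 3), dtype=jnp.bfloat16)
rhs = jnp.ones((3, 4), dtype=jnp.bfloat16)
# Contract dimension 1 of lhs with dimension 0 of rhs; accumulate in float32.
out = lax.dot_general(lhs, rhs,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=jnp.float32)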
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
Mark the index_update() etc. operators as deprecated in the documentation.
Add new .divide and .power operators. Fixes #2694.
Add .multiply as an alias for .mul. To be more NumPy-like we should probably prefer the longer names.
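A short usage sketch of the indexed-update operators referred to above
(the values are illustrative):

import jax.numpy as jnp

x = jnp.arange(6.0)
a = x.at[2].multiply(10.0)  # alias for .mul
b = x.at[3].divide(2.0)
c = x.at[4].power(2.0)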
There are a few test cases that generate millions of configurations,
only to have a handful of them selected by `cases_form_list`. I've
found all tests that spend over 100ms in case generation and
converted them to a new "test sampler" approach. The result: test
generation time drops from 15s to around 2s. Doesn't sound like much,
but I expect that we all run tests many times daily, so it seems like a
useful thing to have.
The rough idea is that the sampling generators get parameterized by a
sampler function that should be applied to the range of every `for` loop.
This allows us to sample runs of the generator through different
configurations by restricting each loop to a smaller subset. Right now
we always narrow it down to a single randomly selected instance. But,
we still retain the possibility of adding exhaustive testing in the
future, which can be achieved by passing in an identity sampling
function that wouldn't modify any loop ranges.
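A rough sketch of the approach; the names below are illustrative, not the
actual test-harness API:

import random

def sample_one(rng):
  # Returns a sampler that narrows every loop to one randomly chosen element.
  def sampler(seq):
    seq = list(seq)
    yield seq[rng.randrange(len(seq))]
  return sampler

def exhaustive(seq):
  # Identity sampler: keeps the full loop range, for exhaustive testing.
  yield from seq

def gen_cases(sampler):
  # A sampling generator: the sampler wraps the range of every `for` loop.
  for shape in sampler([(2, 3), (4, 5), (6,)]):
    for dtype in sampler(["float32", "int32"]):
      yield (shape, dtype)

some_cases = list(gen_cases(sample_one(random.Random(0))))  # one sampled case
all_cases = list(gen_cases(exhaustive))                      # all six cases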
Rather than merely reporting a failure in check_grads(), we now
report the *specific* check that failed, e.g., "JVP tangent" or
"VJP of JVP cotangent projection". Gradient tests often fail for
spurious reasons (e.g., due to insufficient precision), so this should
be helpful for debugging.
I tested this manually by relaxing the tolerance for a test in
`linalg_test.py`.
Also, pass the body to XLA JIT when no parallel resources are used.
There is no reason not to do that, given that we already require users to
pay the price of making their code jittable.
* Adds support for jit of pmap and pmap of pmap.
* Also adds a `tap_with_device` optional argument to `id_print` and
`id_tap`, to have the tap function invoked with a device keyword argument
(see the usage sketch below).
* Adds multiple tests involving pmap.
Issue: #5134
Fixes: #5169
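A usage sketch of the new argument, based on the description above (the exact
tap-function signature may differ from this sketch):

import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

def tap_fn(value, transforms, *, device=None):
  # With tap_with_device=True the tap function also receives the device
  # on which the tapped value was computed.
  print(f"value {value} tapped on device {device}")

def f(x):
  x = hcb.id_tap(tap_fn, x, tap_with_device=True)
  return x + 1

jax.jit(f)(jnp.arange(4.0))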
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
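A rough sketch of the hashing idea (not the actual HashableFunction
implementation): compare and hash wrapped functions by a key that includes the
function's bytecode together with its closed-over values.

class HashableFunctionSketch:
  def __init__(self, fn, closure=()):
    self.fn = fn
    # Including the bytecode means two wrappers only compare equal when the
    # underlying code objects match, not merely the closure values.
    self.key = (fn.__code__.co_code, closure)

  def __call__(self, *args, **kwargs):
    return self.fn(*args, **kwargs)

  def __hash__(self):
    return hash(self.key)

  def __eq__(self, other):
    return isinstance(other, HashableFunctionSketch) and self.key == other.key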