Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.
Co-authored-by: Matthew Johnson <mattjj@google.com>
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
An anticipated common use of `custom_vmap` is in order to implement a
map via loop (i.e. to sequentially apply the mapped function), instead
of actually vectorizing.
It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options.
For example, one can now write things like:
```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
%0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
%1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
%2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
%3 = mhlo.add %2, %1 : tensor<1000xf32>
return %3 : tensor<1000xf32>
}
}
```
Fixes https://github.com/google/jax/issues/9226
PiperOrigin-RevId: 422855649
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:
```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...
@f.def_transpose
def f_t(y): ...
@f.defjvp
def f_jvp(x, tx): ...
```
In particular:
* Forward `def*` methods on custom transformations.
* Have unary `def*` methods return their argument so that, when used
as decorators, they do not replace their target with `None`.
* Fix a bug in the use of `functools.update_wrapper`: previously a
wrapper would overwrite its own attributes with those of the target
callable (including its reference to the target callable).
The dispatch code was always raising its own exception when an exception
occurred during hashing of static arguments, even if the exception which
occurred was something like a KeyboardInterrupt.
fixes#9082
PiperOrigin-RevId: 420292886
Functions decorated by linear_util.transformation or
transformation_with_aux are coroutines (with two yields). They can raise
exceptions, either before or after they yield the first time.
linear_util.WrappedFun.call_wrapped, which is responsible for driving
these coroutines, holds references to them.
These coroutines often manipulate global trace state (i.e.
core.thread_local_state.trace_state attributes) through context managers
(e.g. core.new_main or core.extend_axis_env). These context managers use
try/finally to clean up their state changes.
When an exception is raised in a linear_util.transformation coroutine,
it is raised into call_wrapped. If call_wrapped doesn't then clean up
all the references it has to coroutines, the cleanup finally clauses may
not execute until too late.
To ensure the finally clauses are called at the right time (before
exiting call_wrapped, basically as soon as possible) we need to clean up
the references to the coroutines in call_wrapped.
We had cleaned up these coroutine references when the coroutines raised
exceptions in their first part (i.e. before their first yield) in #4226.
But we didn't do a similar cleanup for their second part (i.e. after
their first yield and before their second).
Co-authored-by: Roy Frostig <frostig@google.com>
mro() has a different signature on metaclasses, but __mro__ is a cached tuple property that appears to have the same signature everywhere. As far as I can tell, it always exists.
PiperOrigin-RevId: 416410647
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.
The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.
PiperOrigin-RevId: 414704700