564 Commits

Author SHA1 Message Date
Jean-Baptiste Lespiau
17f11e05e0 Add accessors on Compiled returning the args and kwargs PyTreeDef working for all transforms.
This also documents the fact that `in_tree` content varies, based on the transform.

PiperOrigin-RevId: 432895923
2022-03-07 02:36:42 -08:00
Roy Frostig
947b7b88e1 re-implement custom_transpose without upfront staging.
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Roy Frostig
d636e74626 make xla_executable a property, consistent across executable types
Also test IR and executable-related methods of `Lowered` and
`Compiled`.
2022-02-25 19:05:44 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Parker Schuh
662c4416a3
Merge branch 'main' into opt-barrier 2022-02-15 14:16:20 -08:00
Peter Hawkins
b0b8f037b0 [JAX] Fix crash when applying jit() to a callable that is not weak-referenceable.
Fixes https://github.com/google/jax/issues/9541

PiperOrigin-RevId: 428829999
2022-02-15 11:18:05 -08:00
jax authors
f229a703e7 Merge pull request #9562 from jakevdp:disable-rank-promotion
PiperOrigin-RevId: 428579739
2022-02-14 12:27:22 -08:00
Parker Schuh
7ce911b8d1 Add translation rule for optimization barrier.
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
2022-02-14 12:21:16 -08:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
Peter Hawkins
5a259925a0 Add constant handler for tokens.
Fixes https://github.com/google/jax/issues/9438
2022-02-14 12:09:29 -05:00
Jake VanderPlas
4f6004a3c9 JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 428489444
2022-02-14 06:20:42 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Lena Martens
1340fbbc09 Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 427744848
2022-02-10 07:39:50 -08:00
jax authors
fe14530347 Merge pull request #9391 from jakevdp:fix-constant-handler
PiperOrigin-RevId: 425677978
2022-02-01 11:44:09 -08:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
Jake VanderPlas
37e73fce7f Add complex types to mlir constant handlers 2022-01-31 10:56:52 -08:00
Yash Katariya
f3ae2c0dfe Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 424968695
2022-01-28 15:16:55 -08:00
Tom Hennigan
ace8c0a53a Strip named_shape and weak_type from aval when donating buffers.
PiperOrigin-RevId: 424888671
2022-01-28 09:35:54 -08:00
jax authors
e6f9ba0a14 Merge pull request #9275 from froystig:auto-vmap
PiperOrigin-RevId: 424765479
2022-01-27 19:38:31 -08:00
jax authors
ea4043a5cc Merge pull request #9325 from jakevdp:fix-rank-promotion-test
PiperOrigin-RevId: 424195563
2022-01-25 15:33:19 -08:00
Jake VanderPlas
4c3473dd74 Make forces_retrace tests more robust 2022-01-25 14:53:40 -08:00
Peter Hawkins
6bda1e5dd8 [JAX] Require exact type equality using is for static arguments.
Fixes https://github.com/google/jax/issues/9273.

PiperOrigin-RevId: 424182826
2022-01-25 14:36:02 -08:00
Jake VanderPlas
3197aacbfc disable implicit rank promotion for api_test 2022-01-24 13:44:04 -08:00
jax authors
21ddd83615 Merge pull request #8420 from Huizerd:dev/fwd
PiperOrigin-RevId: 423232626
2022-01-20 21:45:53 -08:00
Roy Frostig
30c3a39467 introduce custom_batching.sequential_vmap
An anticipated common use of `custom_vmap` is in order to implement a
map via loop (i.e. to sequentially apply the mapped function), instead
of actually vectorizing.
2022-01-20 21:37:58 -08:00
Huizerd
d05431a1ff has_aux for jvp and forward-mode AD value_and_grad
Changes:
- revert value_and_grad_fwd
- add has_aux to jacfwd and jacrev
- tests

fix mypy error
2022-01-20 11:13:58 +01:00
Peter Hawkins
3fef74b2d0 [JAX] Change signature of .mhlo() method on compiler IR objects to return an ir.Module object instead of its string representation.
It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options.

For example, one can now write things like:

```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
  func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
    %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
    %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
    %3 = mhlo.add %2, %1 : tensor<1000xf32>
    return %3 : tensor<1000xf32>
  }
}
```

Fixes https://github.com/google/jax/issues/9226

PiperOrigin-RevId: 422855649
2022-01-19 11:04:48 -08:00
Roy Frostig
a855890d46 custom vmap: abstract eval and translation rules
Also fix and test a tree-flattening bug in the custom_vmap batching
rule.
2022-01-18 15:48:29 -08:00
jax authors
436ce7904c Merge pull request #9175 from froystig:custom-xform-wrappers-forward-attrs
PiperOrigin-RevId: 421449851
2022-01-12 19:11:22 -08:00
Matthew Johnson
08aec823fd fix a custom_vjp post_process bug, related cleanups
related to #8783, doesn't completely fix it
2022-01-12 07:51:50 -08:00
Roy Frostig
ddc1c3e9bd enable custom transformation "stacking"
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:

```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...

@f.def_transpose
def f_t(y): ...

@f.defjvp
def f_jvp(x, tx): ...
```

In particular:

* Forward `def*` methods on custom transformations.

* Have unary `def*` methods return their argument so that, when used
  as decorators, they do not replace their target with `None`.

* Fix a bug in the use of `functools.update_wrapper`: previously a
  wrapper would overwrite its own attributes with those of the target
  callable (including its reference to the target callable).
2022-01-11 17:55:08 -08:00
Roy Frostig
1709e06800 introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
Matthew Johnson
3a2fb1844c Fix exception handling logic in C++ dispatch code.
The dispatch code was always raising its own exception when an exception
occurred during hashing of static arguments, even if the exception which
occurred was something like a KeyboardInterrupt.

fixes #9082

PiperOrigin-RevId: 420292886
2022-01-07 07:59:29 -08:00
Roy Frostig
ad7c7d6eab custom batching jvp tests 2022-01-05 18:07:20 -08:00
Roy Frostig
0ab93a039e custom batching vmap tests 2022-01-05 18:07:20 -08:00
Matthew Johnson
9ee4b92f17 clean up WrappedFun.call_wrapped refs on exception
Functions decorated by linear_util.transformation or
transformation_with_aux are coroutines (with two yields). They can raise
exceptions, either before or after they yield the first time.

linear_util.WrappedFun.call_wrapped, which is responsible for driving
these coroutines, holds references to them.

These coroutines often manipulate global trace state (i.e.
core.thread_local_state.trace_state attributes) through context managers
(e.g. core.new_main or core.extend_axis_env). These context managers use
try/finally to clean up their state changes.

When an exception is raised in a linear_util.transformation coroutine,
it is raised into call_wrapped. If call_wrapped doesn't then clean up
all the references it has to coroutines, the cleanup finally clauses may
not execute until too late.

To ensure the finally clauses are called at the right time (before
exiting call_wrapped, basically as soon as possible) we need to clean up
the references to the coroutines in call_wrapped.

We had cleaned up these coroutine references when the coroutines raised
exceptions in their first part (i.e. before their first yield) in #4226.
But we didn't do a similar cleanup for their second part (i.e. after
their first yield and before their second).

Co-authored-by: Roy Frostig <frostig@google.com>
2022-01-05 16:11:48 -08:00
jax authors
113cd9b939 Merge pull request #8947 from mattjj:issue8910
PiperOrigin-RevId: 417869468
2021-12-22 12:41:45 -08:00
Matthew Johnson
a582fa8748 add limit to number of tracer provenance lines
fixes #8910
2021-12-22 12:18:47 -08:00
Jake VanderPlas
d2908af8de Add item() method to abstract arrays 2021-12-15 16:22:26 -08:00
Peter Hawkins
0c169764ed Use .__mro__ instead of .mro() when enumerating superclasses of a type.
mro() has a different signature on metaclasses, but __mro__ is a cached tuple property that appears to have the same signature everywhere. As far as I can tell, it always exists.

PiperOrigin-RevId: 416410647
2021-12-14 15:36:25 -08:00
Matthew Johnson
c8a34fe5cc add jax.block_until_ready function
fixes #8536
2021-12-14 11:02:14 -08:00
Matthew Johnson
ed365636bf add simple test 2021-12-13 22:11:38 -08:00
Peter Hawkins
add967db88 [JAX] Add a dialect option to jit(...).lower(...).compiler_ir().
The dialect allows the user to select between HLO and MHLO output.

PiperOrigin-RevId: 415591372
2021-12-10 13:02:25 -08:00
Jake VanderPlas
df0969961b Testing: avoid hard-coding random seeds 2021-12-10 10:32:09 -08:00
Roy Frostig
b980acf375 detect and err on transformation of AOT-compiled function calls 2021-12-07 17:20:27 -08:00
Peter Hawkins
06cd1fedee Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
2021-12-07 06:13:07 -08:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Peter Hawkins
db3c3aae87 [JAX] Correctly propagate Python errors out of pytree code when the keys of an enum value cannot be sorted.
Also catch std::runtime_error since the pytree code may throw it.

PiperOrigin-RevId: 413160923
2021-11-30 08:50:30 -08:00