12 Commits

Author SHA1 Message Date
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
f680269a4f [dynamic-shapes] initial support for dynamic shape typechecks
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-06-17 14:57:19 -07:00
Peter Hawkins
a48752a578 [MHLO] Remove most XLA translation rules.
Almost all XLA translation rules have MHLO equivalents at this point, and there are no code paths that use the XLA translation rules in preference to their MLIR equivalents.

PiperOrigin-RevId: 442547482
2022-04-18 08:28:35 -07:00
Peter Hawkins
cb4abe754a [MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Roy Frostig
a6a43e2715 allow for recursive uses of custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Roy Frostig
45af307a61 staging and compilation for custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-16 18:50:00 -07:00
Roy Frostig
947b7b88e1 re-implement custom_transpose without upfront staging.
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-04 16:50:51 -08:00
Roy Frostig
ddc1c3e9bd enable custom transformation "stacking"
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:

```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...

@f.def_transpose
def f_t(y): ...

@f.defjvp
def f_jvp(x, tx): ...
```

In particular:

* Forward `def*` methods on custom transformations.

* Have unary `def*` methods return their argument so that, when used
  as decorators, they do not replace their target with `None`.

* Fix a bug in the use of `functools.update_wrapper`: previously a
  wrapper would overwrite its own attributes with those of the target
  callable (including its reference to the target callable).
2022-01-11 17:55:08 -08:00
Roy Frostig
1709e06800 introduce custom_transpose and a corresponding primitive
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00