22 Commits

Author SHA1 Message Date
Jake VanderPlas
6df80549ef Some minor typing fixes 2021-03-02 09:42:52 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
George Necula
a5fbc2865d Major refactoring of the jax2tf test harnesses.
See the PR description for details.
2020-12-27 17:56:54 +02:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Matthew Johnson
13a17286df
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806) 2020-04-23 13:12:24 -07:00
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alterntaive is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
7dbc8dc1bc
Minimal changes to make Jax pass a pytype check. (#2024) 2020-01-18 08:26:23 -05:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Dougal Maclaurin
6d71396d56 Start exploring jaxprs without tuples
Co-authored-by: Matthew Johnson <mattjj@google.com>
2019-08-21 07:01:07 -07:00
Matthew Johnson
4c2ec3e442 ship it
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:25 -07:00
Matthew Johnson
a17f8e4ca8 add jaxpr eqn structured input, transpose progress
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-05-08 13:41:19 -07:00
Matthew Johnson
4c1fc9cfbd peval.py works again (some paired w/ @dougalm) 2019-02-22 07:53:28 -08:00
Peter Hawkins
5e60639bc5 source sync
PiperOrigin-RevId: 222452709
2018-11-21 20:22:54 -08:00
Peter Hawkins
e180f08113 source sync
PiperOrigin-RevId: 222451919
2018-11-21 20:22:51 -08:00
Matthew Johnson
a30e858e59 populating source tree 2018-11-17 18:03:33 -08:00