189 Commits

Author SHA1 Message Date
Matthew Johnson
b81c246a18
move the trace liveness check from #4312 (#4315) 2020-09-16 23:59:58 -07:00
Matthew Johnson
325d3bc71d
improve an escaped tracer error message (#4312)
* improve an escaped tracer error message

Before this commit, encountering an escaped tracer in a specific way
would lead to a bad internal error. This change
1. raises an UnexpectedTracerError instead, and
2. includes in the error message the user source line which created the
tracer.

* deflake

* replace _live propety with _assert_live method

Thanks @jekbradbury !
2020-09-16 15:59:50 -07:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Adam Paszke
40fb01b4bd Extend axis env while translating the pmapped jaxpr to XLA
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail due to
axis being missing.
2020-09-11 17:56:32 +02:00
Adam Paszke
0aed1f4ddf Add more context to the axis_frame error message.
Some of the vmap and gmap collective tests have been failing on master
and I can't seem to be able to reproduce them locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.
2020-09-07 16:25:30 +02:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Adam Paszke
a33f4dd8c8
Add support for axis_index inside vmap (#4168)
Also, reorganize the code to put all `axis_index` related functions in
`lax_parallel.py`, next to all other parallel collectives.
2020-08-28 20:03:39 +02:00
Adam Paszke
7210d6f5d0 Add support for binding axis_name in gmap
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.

This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
2020-08-28 14:42:01 +02:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00
Sharad Vikram
774b5f688e Remove frame check assertion in extend_axis_env. 2020-08-24 21:13:30 -07:00
Matthew Johnson
66a02b6971 only construct one axis_index_p primitive
Before this change, there were two versions, one used with omnistaging
and one without. But that made bookkeeping hard and buggy. This change
defines the axis_index_p primitive in core.py. Some of its rules are
still changed when omnistaging is enabled.
2020-08-21 17:43:15 -07:00
Matthew Johnson
b89223629e
remove check for TypedJaxpr literals arent tracers (#4096)
In the original usage of TypedJaxpr, literals could not be tracers
because they were only produced by initial-style transformations of
jaxprs. But now TypedJaxpr is used in several other ways, e.g. in
make_jaxpr, and moreover its avals are redundant. It should probably be
renamed ClosedJaxpr since it mainly serves to package a jaxpr together
with its constant arrays. This check was limiting the utility of
TypedJaxpr, and it was only added relatively recently anyway.
2020-08-18 21:04:14 -07:00
Stephan Hoyer
decd760020
Add experimental __array_module__ method (#4076)
* Add experimental __array_module__ method

xref https://github.com/google/jax/issues/1565

`__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html))
is an experimental alternative to `__array_function__` and `__array_ufunc__`
for "duck array" compatibility with NumPy that promises to be much less
invasive.

Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can
be tested with https://github.com/seberg/numpy-dispatch.

My reasoning for merging it into JAX (on an experimental basis with no
guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and
   `__array_ufunc__`, `__array_module__` will always require an explicit
   opt-in by libraries that use it by calling `get_array_module()`.
2. Other NumPy developers
   [want evidence](https://github.com/numpy/numpy/pull/16935#issuecomment-673951287)
   that this is actually feasible.
3. Scikit-Learn developers like @thomasjpfan are interested in exploring
   supporting scikit-learn on top of NumPy-like libraries like JAX, and
   experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as a optional testing requirement in
order to verify that this works. If desired, we could remove this from CI, but
installing numpy-dispatch (and its build requirement Cython) appears to only
add a few seconds of build time.

* don't explicitly list cython

* remove UnshpaedArray from _JAX_ARRAY_TYPES

* Remove incorrect note about metaclasses

* remove unnecessary numpy_dispatch.ensure_dispatching()
2020-08-18 09:40:57 -07:00
Jake Vanderplas
8923bab50d
fixes for pytype (#4068) 2020-08-14 12:53:02 -07:00
Jake Vanderplas
c311fb77fb
Make it possible to override raise_to_shaped for new types (#4064) 2020-08-14 11:51:19 -07:00
Adam Paszke
b75bae6437
Initial version of vmap collectives (#4005)
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules which we might
want to fix in the future.

This feature is also omnistaging-only.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-08-14 18:22:04 +02:00
Julius Kunze
2e873b7d05
Fix jnp.right_shift incorrect on unsigned ints (#3958) 2020-08-05 18:36:46 -07:00
Matthew Johnson
ff96de935b
add dummy eval context (#3932) 2020-07-31 22:20:58 -07:00
Roy Frostig
cd64d2eed5 typecheck scan and cond params 2020-07-31 15:58:13 -07:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Matthew Johnson
c9d8acd2e9
put core trace state in a threading.local class (#3869)
this is a refinement of the fix in #3845, so that we no longer need
TraceState.set_state (and so that #3370 is easier to adapt)
2020-07-26 22:38:14 -07:00
Peter Hawkins
53a4538129
Fix source_info crash in Jaxpr printing (#3849) 2020-07-24 11:52:32 -04:00
Matthew Johnson
cc9528d97d
fix thread locality bug in custom_derivatives (#3845)
* fix thread locality bug in custom_derivatives

fixes #3843
2020-07-23 19:49:04 -07:00
Jake Vanderplas
a7c2cdea64
Cleanup: convert uses of import numpy as onp in library code (#3754) 2020-07-14 13:05:31 -07:00
Matthew Johnson
49cfe2687c
improve concreteness error message for nn.one_hot (#3656)
* improve nn.one_hot and jax.numpy.arange errors

fixes #3654

* deflake

* debug
2020-07-03 20:54:25 -07:00
Roy Frostig
101581d872 add source info to jaxpr typechecking messages 2020-06-29 14:06:31 -07:00
Roy Frostig
c7e97b79f4 limit jaxpr context in typechecker error messages 2020-06-26 15:31:04 -07:00
Roy Frostig
c54acbf903 introduce custom typecheck rules, implement them for cond and scan 2020-06-26 15:31:04 -07:00
Roy Frostig
6b3b42d9c5 raise a custom error in jaxpr checker 2020-06-26 15:31:04 -07:00
Matthew Johnson
a45e28377f
add back a full_lower, dropped in #3491 (#3530) 2020-06-23 12:08:12 -07:00
Matthew Johnson
75278309aa
refactor call primitives, simpler param processing (#3491) 2020-06-23 09:39:45 -07:00
Peter Hawkins
3290e16a9a
Attach source info to Jaxpr equations. (#3421)
* Attach source info to Jaxpr equations.

Example:
```
In [1]: import jax, jax.numpy as jnp
In [2]: def f(x, y):
   ...:    z = jax.numpy.cos(x)
   ...:    z = z * jax.numpy.tanh(y)
   ...:    return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-06-17 16:35:36 -07:00
Roy Frostig
a70ba920fe jaxpr pretty-print: wrap equation RHS when the LHS is long 2020-06-16 13:49:13 -07:00
Roy Frostig
15bc62204e jaxpr: support dropped assignment 2020-06-09 13:47:17 -07:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Roy Frostig
dc4c9f0450 change cond primitive to an indexed conditional with multiple branch functions
in the core:

* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:

* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
2020-06-03 22:19:15 -07:00
Matthew Johnson
177e7cf311
moved check_jaxpr code around to match eval_jaxpr (#3240)
* moved check_jaxpr code around to match eval_jaxpr

This change is mostly stylistic; it brings check_jaxpr closer to
eval_jaxpr (and the other jaxpr interpreters) in organization. There's a
slight tweak to an error message which lets us save some slightly
redundant code.

* fixes and tweaks
2020-06-02 19:10:55 -07:00
Peter Hawkins
042df4ebff
Fix pytype errors. (#3291) 2020-06-02 10:26:43 -04:00
Peter Hawkins
34065df248
Add some type annotations to core and partial_eval. (#3251) 2020-06-01 21:45:36 -04:00
Matthew Johnson
49a441f745
revisions to #3197 (#3264)
revert find_top_trace change from #3197

The previous version was written and tested for performance; the revised
version caused at least a 25% slowdown in the dispatch time of
`lax.add(1, 2)` (and so likely a much bigger slowdown for the
find_top_trace timing alone).

Instead, we can just change the error message in xla.abstractify, since
invalid types lead to abstractification errors when we apply primitive
impls.
2020-06-01 13:24:40 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
Roy Frostig
e80e9634a7 jaxpr-dependent gensym to avoid var duplication 2020-05-27 12:03:34 -07:00
George Necula
f1ae2166d0
Added argument check to all primitives. (#3197)
* Added argument check to all primitives.

The issue that inspired this is that `lax.tie_in` is
easy to misuse if the first argument is not a JAX type, then
it silently disappears. This means that `lax.tie_in((x, x), const)`
is the same as `const` even though `x` is a tracer.

This error would be caught previously if core.skip_checks == False
because then `bind` checks its arguments. I have essentially added
an unconditional argument check to `bind`.

In case this is considered too inefficient, we can add argument
checking to individual primivites, e.g., tie_in. For most primitives
if a non-JAX array is passed, the `impl` rule would fire and `numpy`
would report the error somehow, perhaps.

* Merged find_top_trace with check_args

This was previously merged as #2948 but reverted awaiting the fixes
in some user code.
2020-05-24 19:12:37 +03:00
Roy Frostig
c293a102b2 work around mypy 2020-05-21 20:54:02 -07:00
Roy Frostig
69d7bcf7fb except-and-raise during jaxpr checking, adding jaxpr as context, and simplify type environment 2020-05-21 20:02:30 -07:00
Roy Frostig
8e61ce8d1a fix unitvar comparisons and move to class attributes 2020-05-21 18:28:09 -07:00
Roy Frostig
7ff389bd03 extend type transfer to all primitives, including call and map primitives 2020-05-21 13:21:07 -07:00
Roy Frostig
e2cc568997 raise type errors consistently in jaxpr checker 2020-05-21 13:21:07 -07:00
Roy Frostig
1e55603344 avoid attempt to read literals from the typechecking environment 2020-05-21 13:21:07 -07:00