219 Commits

Author SHA1 Message Date
jax authors
d158647c83 Merge pull request #4706 from apaszke:vmap-collectives-in-scan
PiperOrigin-RevId: 339646941
2020-10-29 05:11:23 -07:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Adam Paszke
6348a99fb4 Add support for vmap collectives in control flow primitives
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function hasn't been updated to support
axis_name when I added support for vmap collectives.
2020-10-26 12:09:18 +00:00
George Necula
c5b983c0de
Update jax/core.py
Co-authored-by: Roy Frostig <froystig@users.noreply.github.com>
2020-10-21 07:53:37 +03:00
George Necula
cb591eb77f
Update jax/core.py
Co-authored-by: Roy Frostig <froystig@users.noreply.github.com>
2020-10-20 22:28:22 +03:00
George Necula
09653bf544 Ensure that check_jaxpr is done with abstract values
Prior to this it was possible, e.g., for code that contains a Literal,
such as  to result in FLOPS during checking.

The assertion is broken by many tests unless we raise_to_shape for Literals.

I have timed the checks on my laptop and I do not see a reduction in the
total test time.
2020-10-20 11:08:41 +03:00
Matthew Johnson
a493a0f43d ensure ConcreteArray equality stays in Python 2020-10-16 18:21:01 -07:00
jax authors
4a20eea828 Copybara import of the project:
--
609f6f3e16d21fed34cc5269c54a0d78ac44a8bc by Matthew Johnson <mattjj@google.com>:

fix custom_jvp/vjp closure issues

PiperOrigin-RevId: 337457689
2020-10-16 00:21:32 -07:00
Jean-Baptiste Lespiau
b13775f464 Enrich the error messages with the bound names that are available.
The user often do not know whether it's not the correct name, or whether it was not defined, etc. It's easier to get this information when debugging.
2020-10-12 20:33:27 +02:00
George Necula
0213efdf4d [jax2tf] Port jax2tf to use omnistaging
The main change is that we use `core.new_base_main` to use an
omnistaging-based tracer. This has the benefit that we can
convert to TF even functions with no arguments (previously
they would be constant-folded by JAX prior to the conversion).

We also add an explicit error if the jax2tf.convert transformation
is nested under other JAX transformations.
2020-10-09 18:42:28 +03:00
Matthew Johnson
52fe026c09 optimize scan partial_eval to fix #4510
fixes #4510
2020-10-08 20:34:34 -07:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Roy Frostig
e7979258ee equation context for undefined var reads in jaxpr typechecker 2020-10-05 12:29:43 -07:00
Lena Martens
cc0114a0a9 Fix dtype behavior with float0s in CustomVJP. 2020-10-01 15:17:51 +01:00
jax authors
69fda9ecb9 Merge pull request #4039 from LenaMartens:changelist/325216264
PiperOrigin-RevId: 334728148
2020-09-30 19:25:00 -07:00
Akihiro Nitta
d707ae17e5
Merge branch 'master' into use-raise-from 2020-10-01 00:27:03 +09:00
Jake VanderPlas
40016cc47c Allow jax objects to be represented by multiple buffers 2020-09-29 11:53:17 -07:00
Akihiro Nitta
06170da69a
Use raise from 2020-09-30 01:20:00 +09:00
Peter Hawkins
a0e14b0552
Revert "Allow JAX objects to be represented by multiple buffers" 2020-09-29 09:26:11 -04:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
jax authors
fa1133885b Merge pull request #3983 from jakevdp:device-put-tuple
PiperOrigin-RevId: 334180017
2020-09-28 10:00:49 -07:00
jax authors
5b3cbc5e18 Merge pull request #4342 from google:improve-tracer-error
PiperOrigin-RevId: 333841912
2020-09-25 17:44:21 -07:00
Matthew Johnson
f66c2eefe1
Merge branch 'master' into improve-tracer-error 2020-09-25 15:13:14 -07:00
Jake VanderPlas
185590fcbf Use core.concrete_or_error() to improve errors in reductions 2020-09-25 14:18:46 -07:00
Jake VanderPlas
d1f80228e0 Allow jax objects to be represented by multiple buffers 2020-09-25 11:09:08 -07:00
Adam Paszke
e0d1b375fa Delete dead axis_index code
The primitive was moved to `lax_parallel.py` some time ago, so the one
in `core` should no longer be used. This is probably a result of a
botched rebase.
2020-09-22 13:08:38 +00:00
Matthew Johnson
3a77b2fac6 Improve a tracer error message
Previously, given this function:

```python
@jax.jit
def f(x,y):
  if x > y:
    return x
  else:
    return y
```

we'd get an error message like this (after #4038, improved to help with
omnistaging debugging):

```
...

While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:

  operation c:bool[] = gt a:int32[] b:int32[]
    from line tim.py:5 (f)

...
```

But this message is buggy! In this case, the value is a tracer because
it has a data dependence on arguments to a jitted function.

After this change, we instead produce this error message:

```
...

While tracing the function f at tim.py:3, this concrete value was not available in Python because it depends on the value of the arguments to f at tim.py:3 at positions [0, 1], and the computation of these values is being staged out.

...
```

I'm eager to iterate with further improvements, but for now I want to
fix this buggy message.
2020-09-18 10:38:37 -07:00
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00
Matthew Johnson
11007ba0e3
test eval_context works w/ and w/o omnistaging (#4325) 2020-09-17 09:57:43 -07:00
Srijan Saurav
40e20242db
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
2020-09-17 09:21:18 -07:00
Matthew Johnson
b81c246a18
move the trace liveness check from #4312 (#4315) 2020-09-16 23:59:58 -07:00
Matthew Johnson
325d3bc71d
improve an escaped tracer error message (#4312)
* improve an escaped tracer error message

Before this commit, encountering an escaped tracer in a specific way
would lead to a bad internal error. This change
1. raises an UnexpectedTracerError instead, and
2. includes in the error message the user source line which created the
tracer.

* deflake

* replace _live propety with _assert_live method

Thanks @jekbradbury !
2020-09-16 15:59:50 -07:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Adam Paszke
40fb01b4bd Extend axis env while translating the pmapped jaxpr to XLA
This is normally unnecessary, because the XLA translation usually
doesn't bind any of the primitives in the jaxpr, but this is not true in
case of scan! Its translation rule reevaluates the jaxpr as a function,
and if it contains collectives such as `axis_index` it can fail due to
axis being missing.
2020-09-11 17:56:32 +02:00
Adam Paszke
0aed1f4ddf Add more context to the axis_frame error message.
Some of the vmap and gmap collective tests have been failing on master
and I can't seem to be able to reproduce them locally. Hopefully, if
this happens again, this extra bit of information will be useful in
debugging the problem.
2020-09-07 16:25:30 +02:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Adam Paszke
a33f4dd8c8
Add support for axis_index inside vmap (#4168)
Also, reorganize the code to put all `axis_index` related functions in
`lax_parallel.py`, next to all other parallel collectives.
2020-08-28 20:03:39 +02:00
Adam Paszke
7210d6f5d0 Add support for binding axis_name in gmap
This allows executing collectives over the gmapped axes. This requires
some extra manipulation of the gmapped jaxpr, since gmap exposes a
single logical axis name, but evaluates the program using multiple
"physical" axes.

This also fixes some bugs around handling `multiple_returns` in
vmap collective implementation.
2020-08-28 14:42:01 +02:00
Tom Hennigan
f0fb7d0925
Use omnistaging env var even when not using absl flags for config. (#4152) 2020-08-26 14:06:27 -07:00
Sharad Vikram
774b5f688e Remove frame check assertion in extend_axis_env. 2020-08-24 21:13:30 -07:00
Matthew Johnson
66a02b6971 only construct one axis_index_p primitive
Before this change, there were two versions, one used with omnistaging
and one without. But that made bookkeeping hard and buggy. This change
defines the axis_index_p primitive in core.py. Some of its rules are
still changed when omnistaging is enabled.
2020-08-21 17:43:15 -07:00
Matthew Johnson
b89223629e
remove check for TypedJaxpr literals arent tracers (#4096)
In the original usage of TypedJaxpr, literals could not be tracers
because they were only produced by initial-style transformations of
jaxprs. But now TypedJaxpr is used in several other ways, e.g. in
make_jaxpr, and moreover its avals are redundant. It should probably be
renamed ClosedJaxpr since it mainly serves to package a jaxpr together
with its constant arrays. This check was limiting the utility of
TypedJaxpr, and it was only added relatively recently anyway.
2020-08-18 21:04:14 -07:00
Stephan Hoyer
decd760020
Add experimental __array_module__ method (#4076)
* Add experimental __array_module__ method

xref https://github.com/google/jax/issues/1565

`__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html))
is an experimental alternative to `__array_function__` and `__array_ufunc__`
for "duck array" compatibility with NumPy that promises to be much less
invasive.

Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can
be tested with https://github.com/seberg/numpy-dispatch.

My reasoning for merging it into JAX (on an experimental basis with no
guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and
   `__array_ufunc__`, `__array_module__` will always require an explicit
   opt-in by libraries that use it by calling `get_array_module()`.
2. Other NumPy developers
   [want evidence](https://github.com/numpy/numpy/pull/16935#issuecomment-673951287)
   that this is actually feasible.
3. Scikit-Learn developers like @thomasjpfan are interested in exploring
   supporting scikit-learn on top of NumPy-like libraries like JAX, and
   experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as a optional testing requirement in
order to verify that this works. If desired, we could remove this from CI, but
installing numpy-dispatch (and its build requirement Cython) appears to only
add a few seconds of build time.

* don't explicitly list cython

* remove UnshpaedArray from _JAX_ARRAY_TYPES

* Remove incorrect note about metaclasses

* remove unnecessary numpy_dispatch.ensure_dispatching()
2020-08-18 09:40:57 -07:00
Jake Vanderplas
8923bab50d
fixes for pytype (#4068) 2020-08-14 12:53:02 -07:00
Jake Vanderplas
c311fb77fb
Make it possible to override raise_to_shaped for new types (#4064) 2020-08-14 11:51:19 -07:00
Adam Paszke
b75bae6437
Initial version of vmap collectives (#4005)
This adds support for the basic (associative and commutative)
collectives to vmap. Supporting more complex collectives will
require some more complicated rules. Also, at the moment it is not
possible to use collectives inside `custom_vjp` rules which we might
want to fix in the future.

This feature is also omnistaging-only.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2020-08-14 18:22:04 +02:00
Julius Kunze
2e873b7d05
Fix jnp.right_shift incorrect on unsigned ints (#3958) 2020-08-05 18:36:46 -07:00
Matthew Johnson
ff96de935b
add dummy eval context (#3932) 2020-07-31 22:20:58 -07:00
Roy Frostig
cd64d2eed5 typecheck scan and cond params 2020-07-31 15:58:13 -07:00