241 Commits

Author SHA1 Message Date
jax authors
7c294e62f4 Copybara import of the project:
--
7342318774c6f1195f0e238f1209425109ea8944 by Matthew Johnson <mattjj@google.com>:

check for __jax_array__ method for conversion

--
6742016382b0511f5ac9ec21f67d2122a9f37cb7 by Matthew Johnson <mattjj@google.com>:

fix typo

--
5eb36855e53d8d4e81e281d08dc9264d2671f21f by Matthew Johnson <mattjj@google.com>:

ensure some jnp funs duck-type with __jax_array__

PiperOrigin-RevId: 347763582
2020-12-15 23:13:29 -08:00
Matthew Johnson
7342318774 check for __jax_array__ method for conversion 2020-12-14 17:09:25 -08:00
Adam Paszke
ca8028950e Fix pmap compilation cache regressions from #4904.
AD didn't use `HashableFunction` enough, tripping up the compilation
cache. I've also used the occasion to make function hashing a little
safer by including the Python bytecode of the wrapped function as part
of the key.
2020-12-02 14:40:45 +00:00
Matthew Johnson
8057cf919e simplify vmap collectives from two sets of rules to one
Specifically we:
1. remove the need for split_axis rules in batching.py, and instead just
rely on collective rules (namely to handle vectorizing over a single
named axis even if the collective is applied over multiple named axes)
2. simplify BatchTrace.process_primitive so that we don't pass tracers
into rules and rely on a subtle recursion

This change breaks all_to_all when used with multiple axis names, and in
particular it breaks all_to_all given the current gmap/xmap lowering
strategy of substituting multiple axis names in place of single axis
names. We believe we can replicate the previous logic with the new rule
organization, but we're leaving that for follow-up work because it's
tricky, and because we might end up changing lowering strategies not to
require axis substitution in the same way.
2020-11-25 10:15:21 -08:00
jax authors
c7057d5fb1 Merge pull request #5005 from apaszke:xmap-primitive
PiperOrigin-RevId: 344263137
2020-11-25 09:07:27 -08:00
Adam Paszke
5ee2de1675 Forbid pmap/soft_pmap/sharded_jit inside xmap 2020-11-25 13:47:05 +00:00
Adam Paszke
5879967c25 Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.

One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.

* Implementation details *

This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.

** Thunking **

The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:

*** Transformations ***

Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).

The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.

The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.

*** Compilation cache ***

Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.

Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.

* Why final style? *

Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-24 17:11:38 +00:00
Adam Paszke
2494e0c339 Add XLA lowering for xmap
This should allow us to try out xmap not only in a simulation (i.e.
faking the devices using vmap, which we still support), but also on real
hardware.

Limitations:
- No compilation caching yet
- Nested xmaps not supported yet
- Transforms (AD, vmap, etc.) of xmaps not supported yet

Benefits:
- An xmap over multiple mesh axes already implements a more efficient
  lowering than the one used for nested pmaps.

The `resources` context-manager is now called `fake_resources`, while
real meshes can be defined in a specific context using the
`mesh(devices, axis_names)` manager. `devices` is supposed to be an
`ndarray` of JAX device objects (e.g. obtained from `jax.devices()`),
while `axis_names` should be a tuple of length matching the rank of
`devices` and specifying mesh axis names.

For concrete examples see the changes in `gmap_tests.py`.

In principle the current version of the code should also work in a
multi-host setting, but I haven't tested it just yet.
2020-11-24 11:13:49 +00:00
Peter Hawkins
84c723fc9e [JAX] Move pprint_util into jax._src.
PiperOrigin-RevId: 343279975
2020-11-19 06:42:19 -08:00
jax authors
69c920c601 Merge pull request #4796 from qiuminxu:add_jax_named_call
PiperOrigin-RevId: 342951787
2020-11-17 14:52:52 -08:00
Roy Frostig
78c6e4e5e5 fix check_jaxpr docstring 2020-11-13 18:00:33 -08:00
jax authors
83a38f4f3b Merge pull request #4854 from j-towns:tidy-stack
PiperOrigin-RevId: 342264328
2020-11-13 08:07:45 -08:00
Matthew Johnson
8b006f6a90 add correct annotations to core.TraceStack 2020-11-13 07:23:02 -08:00
Qiumin Xu
0f8ea37556 Update core.py 2020-11-12 17:36:46 -08:00
Qiumin Xu
31600aac62 Add named_call public API.
Move named_call_p to core.py from lax.py.
Also move the translation rule to jax/interpreters/xla.py where the core_call translation rule is.
2020-11-12 17:32:01 -08:00
Adam Paszke
a5bc7353de Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 18:35:28 +00:00
jax authors
bdd7915661 Internal change
PiperOrigin-RevId: 341644256
2020-11-10 10:12:27 -08:00
Adam Paszke
6914058cbe Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 13:35:23 +00:00
Jamie Townsend
b8920a11c3 Rm old attribute annotations from TraceStack 2020-11-10 11:10:06 +00:00
Peter Hawkins
7efc1dbc94 [JAX] Move source_info_util into jax._src.
TFP uses source_info_util, so we leave a forwarding stub until we can update TFP.

PiperOrigin-RevId: 340698612
2020-11-04 11:54:24 -08:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
Adam Paszke
b85e605ff1 Add support for collectives in xmap 2020-11-03 17:52:18 +00:00
jax authors
d158647c83 Merge pull request #4706 from apaszke:vmap-collectives-in-scan
PiperOrigin-RevId: 339646941
2020-10-29 05:11:23 -07:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Adam Paszke
6348a99fb4 Add support for vmap collectives in control flow primitives
All initial style primitives currently use `batch_jaxpr` in their
batching rules, but that function hasn't been updated to support
axis_name when I added support for vmap collectives.
2020-10-26 12:09:18 +00:00
George Necula
c5b983c0de
Update jax/core.py
Co-authored-by: Roy Frostig <froystig@users.noreply.github.com>
2020-10-21 07:53:37 +03:00
George Necula
cb591eb77f
Update jax/core.py
Co-authored-by: Roy Frostig <froystig@users.noreply.github.com>
2020-10-20 22:28:22 +03:00
George Necula
09653bf544 Ensure that check_jaxpr is done with abstract values
Prior to this it was possible, e.g., for code that contains a Literal,
such as  to result in FLOPS during checking.

The assertion is broken by many tests unless we raise_to_shape for Literals.

I have timed the checks on my laptop and I do not see a reduction in the
total test time.
2020-10-20 11:08:41 +03:00
Matthew Johnson
a493a0f43d ensure ConcreteArray equality stays in Python 2020-10-16 18:21:01 -07:00
jax authors
4a20eea828 Copybara import of the project:
--
609f6f3e16d21fed34cc5269c54a0d78ac44a8bc by Matthew Johnson <mattjj@google.com>:

fix custom_jvp/vjp closure issues

PiperOrigin-RevId: 337457689
2020-10-16 00:21:32 -07:00
Jean-Baptiste Lespiau
b13775f464 Enrich the error messages with the bound names that are available.
The user often do not know whether it's not the correct name, or whether it was not defined, etc. It's easier to get this information when debugging.
2020-10-12 20:33:27 +02:00
George Necula
0213efdf4d [jax2tf] Port jax2tf to use omnistaging
The main change is that we use `core.new_base_main` to use an
omnistaging-based tracer. This has the benefit that we can
convert to TF even functions with no arguments (previously
they would be constant-folded by JAX prior to the conversion).

We also add an explicit error if the jax2tf.convert transformation
is nested under other JAX transformations.
2020-10-09 18:42:28 +03:00
Matthew Johnson
52fe026c09 optimize scan partial_eval to fix #4510
fixes #4510
2020-10-08 20:34:34 -07:00
Jake VanderPlas
6393349783 raise_to_shaped: preserve weak_type by default 2020-10-08 11:53:52 -07:00
Roy Frostig
e7979258ee equation context for undefined var reads in jaxpr typechecker 2020-10-05 12:29:43 -07:00
Lena Martens
cc0114a0a9 Fix dtype behavior with float0s in CustomVJP. 2020-10-01 15:17:51 +01:00
jax authors
69fda9ecb9 Merge pull request #4039 from LenaMartens:changelist/325216264
PiperOrigin-RevId: 334728148
2020-09-30 19:25:00 -07:00
Akihiro Nitta
d707ae17e5
Merge branch 'master' into use-raise-from 2020-10-01 00:27:03 +09:00
Jake VanderPlas
40016cc47c Allow jax objects to be represented by multiple buffers 2020-09-29 11:53:17 -07:00
Akihiro Nitta
06170da69a
Use raise from 2020-09-30 01:20:00 +09:00
Peter Hawkins
a0e14b0552
Revert "Allow JAX objects to be represented by multiple buffers" 2020-09-29 09:26:11 -04:00
Lena Martens
ecad419cf3 Support grad with integer arguments.
- Add float0 and set-up at_least_vspace to return float0
values for int/bool primals
- Use Zero to wrap float0 tangents so they're correctly ignored in jvp
rules
- Add float0 handlers to XLA to support jit
- Fix convert_element_type and tie_in jvp rules
2020-09-28 19:07:04 +01:00
jax authors
fa1133885b Merge pull request #3983 from jakevdp:device-put-tuple
PiperOrigin-RevId: 334180017
2020-09-28 10:00:49 -07:00
jax authors
5b3cbc5e18 Merge pull request #4342 from google:improve-tracer-error
PiperOrigin-RevId: 333841912
2020-09-25 17:44:21 -07:00
Matthew Johnson
f66c2eefe1
Merge branch 'master' into improve-tracer-error 2020-09-25 15:13:14 -07:00
Jake VanderPlas
185590fcbf Use core.concrete_or_error() to improve errors in reductions 2020-09-25 14:18:46 -07:00
Jake VanderPlas
d1f80228e0 Allow jax objects to be represented by multiple buffers 2020-09-25 11:09:08 -07:00
Adam Paszke
e0d1b375fa Delete dead axis_index code
The primitive was moved to `lax_parallel.py` some time ago, so the one
in `core` should no longer be used. This is probably a result of a
botched rebase.
2020-09-22 13:08:38 +00:00
Matthew Johnson
3a77b2fac6 Improve a tracer error message
Previously, given this function:

```python
@jax.jit
def f(x,y):
  if x > y:
    return x
  else:
    return y
```

we'd get an error message like this (after #4038, improved to help with
omnistaging debugging):

```
...

While tracing the function f at tim.py:3, this value became a tracer due to JAX operations on these lines:

  operation c:bool[] = gt a:int32[] b:int32[]
    from line tim.py:5 (f)

...
```

But this message is buggy! In this case, the value is a tracer because
it has a data dependence on arguments to a jitted function.

After this change, we instead produce this error message:

```
...

While tracing the function f at tim.py:3, this concrete value was not available in Python because it depends on the value of the arguments to f at tim.py:3 at positions [0, 1], and the computation of these values is being staged out.

...
```

I'm eager to iterate with further improvements, but for now I want to
fix this buggy message.
2020-09-18 10:38:37 -07:00
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00