57 Commits

Author SHA1 Message Date
Yash Katariya
fbc05ee5ac Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago
PiperOrigin-RevId: 520356179
2023-03-29 09:23:22 -07:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Yash Katariya
e4d551a217 Remove the doctest skip now that jit and pjit have been merged
PiperOrigin-RevId: 508162840
2023-02-08 13:09:53 -08:00
yashkatariya
2cfec044bf Fix the jaxpr after jit-pjit merge 2023-02-08 09:12:01 -08:00
yashkatariya
d3eef935f7 Fix the jaxpr after jit-pjit merge 2023-02-08 08:52:57 -08:00
Lena Martens
caf4f7b3f7 Lift global_axis calculation from lowering in pxla.py to api.py.
Add an "explicit_global_axis_size" arg. `global_axis` used to be set to `None`
when the user did not provide an explicit axis size. After this change,
`global_axis` should never be set to `None` internally, and always contain the
size of the global axis. It's still useful to thread the information that the
user has provided an explicit axis size so we can throw explicit errors in
`pxla` when explicit axis sizes are not allowed.

Why do we need to do this? We only go down the lowering path when calling
`pmap`s impl rule (while executing or final-style transforming), but not when
initial-style transforming. The global_axis size should be computed earlier,
such that it is available for initial-style transformations/primitives, e.g. if
we round-trip a multi-host pmap computation through make_jaxpr and eval_jaxpr.

We have tests for "initial-style transform of a `pmap`", but no such test for
_multi-host_ `pmap`! Alors, this bug went unnoticed.
#13545 makes `checkify` initial-style, and because `checkify-of-pmap` is a
valid way to check a `pmap`, an internal multi-host test uncovered this bug.

PiperOrigin-RevId: 499877003
2023-01-05 07:54:53 -08:00
Jake VanderPlas
8a3bfe0300 DOC: add references for haskell-style signatures 2022-05-13 12:35:28 -07:00
Matthew Johnson
8430deda3e custom pp_eqn rules, simpler xla_call print 2021-11-23 15:52:52 -08:00
Matthew Johnson
abbf78b5c3 generalize jaxpr simplification machinery
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts

This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.

But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.

In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.

These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
2021-11-19 09:00:59 -08:00
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Peter Hawkins
d4023508a4 Uniquify variable names globally within a jaxpr.
It is confusing when the same name is shadowed within an inner lambda expression. Use globally unique variable names in each pretty-printed jaxpr.
2021-10-01 12:49:47 -04:00
Peter Hawkins
5fa4613e99 Adds a Wadler-Lindig pretty printer.
Changes jaxpr printing to use it.
2021-09-27 21:09:24 -04:00
Peter Hawkins
1163e218e8 Attempt to land https://github.com/google/jax/pull/6400 again.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 398008511
2021-09-21 09:06:40 -07:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00
Parker Schuh
92246017d2
Revert "Use convert_element_type instead of device_put_raw." 2021-05-06 20:19:23 -07:00
Parker Schuh
9d3e535ad2
Merge branch 'master' into convert_element 2021-05-06 13:18:01 -07:00
Matthew Johnson
7ec0b40173 Roll-forward of #6584, which broke internal tests.
PiperOrigin-RevId: 371839298
2021-05-03 21:41:23 -07:00
jax authors
75b00a1235 Copybara import of the project:
--
3c400a3e588abf9e2259119c50343cba6f3477f1 by Matthew Johnson <mattjj@google.com>:

add 'inline' option to xla_call for jaxpr inlining

--
fe297e39ca37896b75d7943b9b77c0b53fad13ee by Matthew Johnson <mattjj@google.com>:

add 'inline' to jit docstring

--
ff6866c4b3757cde66fe659c2f27d8aeff024e8f by Matthew Johnson <mattjj@google.com>:

new_sublevel in jax2tf

PiperOrigin-RevId: 371542778
2021-05-01 22:18:39 -07:00
Matthew Johnson
3c400a3e58 add 'inline' option to xla_call for jaxpr inlining 2021-04-28 19:38:15 -07:00
Matthew Johnson
ba9233b9b6 prune trivial convert_element_types from jaxprs
also add a test for not performing H2D transfers while tracing jnp.array
2021-04-22 12:46:26 -07:00
Jake VanderPlas
123dfb9bf7 DOC: add 'jaxpr' to the glossary 2021-03-02 11:57:50 -08:00
Jake VanderPlas
41b7a0f770 Re-land #4850 weak types change 2021-02-09 09:07:52 -08:00
Adam Paszke
f86bf12b5a Add support for axis names in jnp.{sum,min,max}
Similarly to `jnp.einsum`, whenever we encounter an extension to the
positional NumPy API (in the case of reductions, the extension is
whenever a non-integer axis is specified), we reroute the call to a
parallel primitive instead of the standard lax reductions.

Note that this makes the parallel primitives implement a strict subset
of functionality of the lax reductions so in the future (when we decide
that we want axes to be truly first class) we can always swap out the
implementation for the parallel version. But, it makes sense to keep
them separate for the ease of prototyping in the near future.
2021-02-01 11:41:05 +00:00
Jake VanderPlas
af6da229da DOC: fix some minor formatting issues 2021-01-28 15:20:02 -08:00
Łukasz Lew
9dccf567ce
Clarify tracking
Clarify tracing a bit and use wording that does not suggest that JAX executed python program.
2021-01-25 17:16:29 -08:00
Matthew Johnson
3dee321fb8 rollback of #4850 2020-12-23 11:01:58 -08:00
Jake VanderPlas
c63097bc90 Add weak_type argument to convert_element_type_p 2020-12-10 11:10:21 -08:00
Matthew Johnson
8ed7be4906 make convert_element_type_p not require old_dtype
Previously we needed to add old_dtype to the primitive's parameters so
that it could be transposed. However, now that avals information is
available in more places (in particular, attached to the UndefinedPrimal
instances we use to indicate inputs with respect to which we are
transposing), we don't need that kind of a hack!

This is a follow-up to #2410.
2020-12-03 11:49:43 -08:00
Matthew Johnson
50cb604f2e fix doc test failure 2020-11-25 10:15:21 -08:00
Adam Paszke
5879967c25 Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.

One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.

* Implementation details *

This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.

** Thunking **

The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:

*** Transformations ***

Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).

The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.

The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.

*** Compilation cache ***

Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.

Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.

* Why final style? *

Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-24 17:11:38 +00:00
Skye Wanderman-Milne
4e56cf965a Add support for multi-host partitioning when using pmap(sharded_jit).
This extends the pmap logic in a way similar to
https://github.com/google/jax/pull/4746. The new arguments to
sharded_jit specifying the local partitioning can be reused by pmap,
but with one wrinkle: the pmap implementation needs to trace its jaxpr
to "see" the sharded_jit and get these values, but it needs to know
the global aval shapes in order to correctly trace through the
sharded_jit. For now, we simply add this information as a new
"global_arg_shapes" argument to pmap. Ideally we'll replace this with
a more elegant solution, e.g. global-view device arrays.
2020-11-20 12:50:46 -08:00
Adam Paszke
a5bc7353de Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 18:35:28 +00:00
jax authors
bdd7915661 Internal change
PiperOrigin-RevId: 341644256
2020-11-10 10:12:27 -08:00
Adam Paszke
6914058cbe Add support for pmap in_axes other than 0 and None
... and in map primitives in general (which is why the patch touches
most traces).

This also fixes a bug in the transpose rule for map primitives, which
would fail to adjust the aval associated with zeros returned from the
map body.
2020-11-10 13:35:23 +00:00
Matthew Johnson
6614f94890
rename and simplify TypedJaxpr -> ClosedJaxpr (#4328)
rename and simplify TypedJaxpr -> ClosedJaxpr

This change:
* simplifies code that constructs TypedJaxprs/ClosedJaxprs (because
  in_avals / out_avals no longer need to be constructed), making them
  easier to work with;
* correspondingly rules out a class of errors (mismatches between
  invars/outvars and in_avals/out_avals);
* provides a more descriptive class name (ClosedJaxprs are like jaxprs
  but they're closed in that they are packaged with their constant
  values).

This is part 1 of an attempt to remove TypedJaxprs completely, or at
least significantly reduce our use of them. However, I'm not getting rid
of them entirely in this first step because it'd require bigger changes
(basically allowing all constants to be represented as literals, rather
than only scalars) that would not only touch a lot more code (jaxpr
formation, jaxpr-to-jaxpr transformations, control flow, XLA lowering)
but also might affect XLA lowering right before a conference deadline
(ICLR). Plus I'm trying to make big changes in smaller steps :)

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-09-18 10:07:13 -07:00
Matthew Johnson
2678a4647a
omnistaging on by default (#4038) 2020-09-15 08:06:46 -07:00
Alex Minnaar
64bead2093
fixing typo (#4273)
I assume "...one of more type parameters..." was intended to read "...one or more type parameters..."
2020-09-12 13:10:01 -07:00
Jake Vanderplas
2a33b3d388
fix documentation typo (#4252) 2020-09-10 11:23:29 -07:00
Jake Vanderplas
05904faf0f
Change onp/np to np/jnp in docs & notebooks (#3760) 2020-07-15 13:17:38 -07:00
Roy Frostig
8a62a9b654
block-unrolled scan primitive implementation (#3738)
* block-unrolled scan implementation, via optional `_unroll` scan parameter

* index statically in the inlined path of lax.scan

* make `unroll` a required scan parameter, and test that it unrolls
2020-07-15 14:00:50 -04:00
8bitmp3
242b382bab
Remove a deprecated reference to testExamplesJaxprDoc in Understanding Jaxpr (#3680) 2020-07-07 11:29:44 -07:00
igorwilbert
e5d4ca31a8
Fix typo understanding jaxprs page on readthedocs (#3513) 2020-06-22 12:31:08 -07:00
Roy Frostig
15bc62204e jaxpr: support dropped assignment 2020-06-09 13:47:17 -07:00
Roy Frostig
bd3cab9768 update jaxpr doc to reflect lax.switch and indexed cond 2020-06-03 22:19:15 -07:00
Tom Hennigan
6124f703af
Add support for buffer donation in jit and pmap. (#2936)
For a computation of the form:

    >>> f = lambda x: x ** 2
    >>> f = jax.jit(f)
    >>> while run:
    ...   x = f(x)

JAX must currently always have two copies of `x` in device memory since there
is no reliable way in Python to determine whether there will be future uses of
`x`. This causes two classes of problem:

  1. Users at the limit of available device are constrained by the additional
     copy of their parameters and other state while they typically only require
     one copy. This typically frees 100M+ of device memory and is a critical
     optimization for larger models to match state of the art performance in
     other frameworks.

  2. This constant alloc/free of the input/output buffers can cause memory
     fragmentation on some platforms (although having a reusing allocator and
     limiting run-ahead may be a better solution for this problem).

We propose fixing this by using input/output aliasing as supported by XLA. We
will support this in JAX by allowing certain arguments of jit/pmap decorated
functions to be donated and reused as outputs:

    >>> f = lambda x: x ** 2
    >>> f = jit(f, donate_argnums=0)
    >>> while run:
    ...   x = f(x)

JAX will determine that the donated input `x` can alias with the output of the
function and it will instruct XLA it _must_ write the result to this buffer.

If a user tries to reuse a buffer after it has been donated they get an error
that the buffer is invalid:

    >>> y = f(x)
    >>> jax.device_get(x)
    ...
    RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer.

The semantics of `donate_argnums` follows that of `static_argnums`, namely that
it identifies positional arguments to the computation that are to be donated
to the computation and used as part of the output.

One feature that is also enabled by this is invalidating buffers that should
only be used once, for example PRNGKeys:

    >>> @partial(jit, donate_argnums=0)
    ... def move(x):
    ...   # Do something complex enough for JAX to just optimize it away.
    ...   return tree_map(lambda x: x + x - x, x)

    >>> def safe_eager_uniform(key, *a, **k):
    ...   assert hasattr(key, 'device_buffer'), "random must run eagerly"
    ...   key = move(key)
    ...   return jax.random.uniform(key, *a, **k)

This is not a complete answer to random safety since it is still possible to
reuse a key as part of a traced computation, however it can be used to support
this feature (somewhat inefficiently) in eager mode.
2020-05-31 15:00:16 -07:00
Roy Frostig
916953ace8 update example in jaxpr doc 2020-05-29 16:57:40 -07:00
Roy Frostig
de03c99b52 update jaxpr doc and tests with single-operand cond 2020-05-13 21:14:41 -07:00
James Bradbury
f60184e12e
Support axis_index_groups in allreduce collectives (#2382)
* support replica groups in allreduce collectives

* add test and fix jaxpr in docs

* switch from XLA replica IDs to JAX axis indices

* fix psum transpose rule

* test other nesting order + imperfect nesting

* update jaxpr.rst

* handle None case

* add note+check that groups  cover the index space

* switch split_axis assert to NotImplementedError

* update CHANGELOG
2020-05-08 14:00:34 -07:00
Matthew Johnson
3cd409ee88
add optional 'forward' argument to lax.scan (#2921)
* add optional 'forward' argument to lax.scan

* switch to reverse; revise disable-jit case

* fix jaxpr.rst

* fix loops.py

Co-authored-by: James Bradbury <jekbradbury@gmail.com>
2020-05-04 19:44:22 -07:00
Roman Ring
525235d8c9
Fix a codeblock in the "understanding jaxpr" doc. (#2942)
This fixes an issue where the codeblock didn't render properly on the website.
2020-05-04 13:20:21 +03:00