rolling forward shard_map transpose fixes

The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.
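For illustration only, here is a schematic sketch of that branching in plain Python. The `rewrite` parameter name is from this change, but the helper names and overall structure are hypothetical stand-ins, not the actual shard_map transpose rule:

```python
# Schematic sketch, not the real shard_map transpose rule: the helpers are
# hypothetical and passed in explicitly so the snippet stands alone.
def transpose_shard_map(transposed_body, out_cts, *, rewrite, in_specs, out_specs,
                        psum_for_fanout, div_for_fanin_consensus):
  if rewrite:
    # New path (check_rep=True): the body rewrite already made it transpose
    # correctly, so no extra collectives are inserted here.
    return transposed_body(out_cts)
  # Old path (check_rep=False): rescale cotangents for fan-in-consensus
  # out_specs, then psum the results corresponding to fan-out in_specs.
  out_cts = [div_for_fanin_consensus(ct, s) for ct, s in zip(out_cts, out_specs)]
  in_cts = transposed_body(out_cts)
  return [psum_for_fanout(ct, s) for ct, s in zip(in_cts, in_specs)]
```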

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
Matthew Johnson 2023-08-31 17:30:34 -07:00 committed by jax authors
parent c38f67043c
commit 70b58bbd30
10 changed files with 1399 additions and 126 deletions


@@ -0,0 +1,515 @@
# Efficient transposition of replication-inducing collectives
*mattjj@*, *dougalm@*
*August 2023*
## Motivation
We have an efficiency problem in automatically transposing `shmap`s containing
certain collectives. The issue arises with `psum` and `all_gather`, specifically
when the output of the collective is returned to the caller as an unmapped
output. And it's not an edge case: for example, it arises when applying `grad`
to a `shmap`-based batch data parallel neural network loss function which uses
`psum` to compute the total loss.
We've known about this problem for some time. An analogous issue exists with
`pmap`, though it's been worked around by keeping `grad` inside `pmap` rather than
outside. A primary goal of the incomplete avals-with-names work was to address a
version of this transpose efficiency problem. This doc draws on those ideas,
while extending and revising them to handle more cases and to be much easier to
land. Indeed the solution proposed here only affects the `shmap` implementation.
The rest of the system need not be changed (yet).
The main purpose of this doc is to define this transpose efficiency problem and
propose an easy-to-land solution.
This doc is not about:
* logical axis names on arrays (the only axis names here are just like in
`shmap` and OG `pmap`);
* changing autodiff semantics (all the numbers and (non)errors are staying the
same, we're just making things more efficient);
* allowing user code to reflect on any new information, or really affecting user
code at all.
## Problem: efficient transpose of `psum` or `all_gather` depends on whether cotangents are invariant across devices
Consider this semi-realistic example, meant to resemble a replicated-parameter
batch data parallel loss function:
```python
devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss
```
Notice the `out_specs=P()`, which indicates an unmapped output. If you're not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.
Most of the details in the `loss` example aren't important. All that matters for
our purposes is that we're applying `psum` (or rather `pmean = lambda x, name:
psum(x, name) / psum(1, name)`) at the end. So a distilled version looks like
this:
```python
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
```
We've also simplified notation by suppressing the `mesh` argument. In the examples to
follow it can be inferred from context.
What does the transpose look like? Writing `t` to mean function transpose, we
could evaluate `t(f1)(ybar)` for any `ybar` efficiently by applying the function
`¿f1_transpose?` below:
```python
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
```
But that's not the transpose we currently get as t(f1).
Instead, the current recipe for transposition is roughly that we switch
`in_specs` and `out_specs`, do some division rescaling for unmapped outputs, and
transpose the body. Because `psum` is its own transpose (as an all-reduce sum),
we end up producing this transpose:
```python
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
```
This transpose gets the numbers right, but it's wasteful. We know statically
from the transpose's `in_specs=P()` that `ybar` has the same value for each function
instance, i.e. that its value is device-invariant for devices along the mesh
axis named `i`, and yet we apply a `psum` to it! That uses expensive communication
just to multiply the value on each device by 8. (Here 8 refers to the size of
axis i. The division by 8 comes from the original function's `out_specs=P()`; it
and the trivial `psum` basically cancel each other out.)
What are we doing wrong? We're not exploiting the fact that cotangents `ybar`
corresponding to `f1`'s unmapped outputs are guaranteed to be device-invariant;
instead, we're defensively `psum`ming them as if they weren't because `psum`'s
transpose can't be sure given the local information it has. Sometimes the `psum`
is necessary, as in transposing `f2` with respect to its first argument:
```python
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
```
Intuitively, if our transpose machinery could tell the difference between
Example 1 and Example 2, we could do better by avoiding the psum and division
where possible.
The inefficient examples can be even smaller. Consider transposing this cursed
identity function:
```python
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...
```
It keeps getting bigger the more we transpose. How embarrassing!
And `psum` isn't the only culprit. Something analogous holds true for
`all_gather`:
```python
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
```
This program is a bit artificial. Why do an `all_gather` and feed the result into
an unmapped output, rather than skipping the `all_gather` in the body and just
using `out_specs=P('i')` to collect the results? But even though it's cooked-up,
this example nevertheless exhibits a transpose which unnecessarily performs
communication (we could have just performed a non-communicating slice),
analogous to Example 1 for `psum`.
Also analogously to the `psum` examples, the defensive `psum_scatter` is
necessary in some cases:
```python
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
```
So how do we avoid these inefficient transposes?
## Solutions
Here are two solution ideas. They aren't mutually exclusive. But (spoilers) the
second one is better, and it's all we need.
### Partial solution "P-sum": build the ability to express a `psum` into `out_specs`
This solution is a bit of a strawperson because it would offer only an awkward
way to write programs. And it wouldn't even fix everything! But it's worth
considering, if only to motivate a more complete solution.
Example 4 above is artificial because we could have just used `out_specs` instead
of an `all_gather` in the body:
```python
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
```
The `f4_better` version doesn't have any transposition problems, since the
transpose problems arise from collectives in the body.
Analogously, we could fix Example 1 by extending `out_specs` so that they can
express summing:
```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
```
So offering `psum`s built into `out_specs` fixes the transpose problem of
Example 1. But it doesn't fully fix the cursed identity transpose in Example 3:
```python
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
```
It's an improvement since the program doesn't continue to get bigger as we keep
transposing, but we're still doing wasteful communication.
### Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives
This solution has two components:
1. track when values are guaranteed to be device-invariant vs device-varying
over particular mesh axes, and
2. decompose `psum` into a two-step process, introducing a new `pbroadcast`
primitive, and introduce new primitives for `all_gather` and its transposes.
Morally, the tracking of device-invariant vs device-varying information is a
type-level consideration. But for the expedience of our first implementation, we
don't need to literally add the information to abstract values or jaxpr types.
Before we get to implementation, we'll first introduce the idea using types.
Also to follow is a discussion of making the user API convenient and backward
compatible. But to first introduce the idea, we'll ignore convenience and
instead write code that is as explicit as possible.
#### Tracking device invariance in avals (a.k.a. avals-with-names, revived)
We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a `shmap` are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We'll call such values device-invariant. For values that are not
device-invariant, we'll say they're device-varying, though really we mean
potentially device-varying from the point of view of the type system.
To encode device variance in types, we'll extend the syntax of types for arrays.
We'll write things like `x:f32[3,4]{i}` to indicate that `x` is (potentially)
device-varying along mesh axis `i` (and device-invariant over any other mesh
axes of the `shmap`). More generally, we'll say the grammar for array type
syntax is something like
```
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
```
We'll also update the typing rules to handle device variance types:
* for first-order primitives other than collectives
- for multi-arity primitives, the operand device variance types must be equal
where shapes must be equal, e.g. `mul x:f32[s1]{r1} y:f32[s2]{r2}` requires
`r1 == r2` in addition to `s1 == s2`
- the output device variance type must be the same as the operand(s)
* for higher-order primitives
- we just instantiate any type variables including the device variance type
(and checking types for equality checks their device variance types are
equal)
- (when performing type inference, e.g. for branches of a `cond`, we take the
union of the sets of axis names in device variance types)
* for first-order collectives
- a collective can either accept a device-varying or device-invariant input
(along a mesh axis corresponding to its axis name parameter); it's an error
to pass a device-invariant operand to a collective which accepts
device-varying operands and vice-versa
- a collective can either produce a device-varying or device-invariant output
- see the table below
As a side benefit, whatever logic implements this type checking can subsume
`shmap`'s "static analysis" check for whether a `shmap` body function is
compatible with any unmapped `out_specs`.
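To make the multi-arity rule concrete, here is a minimal self-contained sketch (illustrative only, not JAX's implementation), where a device variance type is just a set of mesh axis names attached to a shaped type:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class VarType:
  shape: tuple         # array shape
  dtype: str           # e.g. 'f32'
  varying: frozenset   # mesh axis names the value may vary over, e.g. {'i'}

def binop_type_rule(x: VarType, y: VarType) -> VarType:
  # Multi-arity first-order primitive: where shapes must agree, device
  # variance types must agree too; the output inherits them.
  if x.shape != y.shape or x.dtype != y.dtype:
    raise TypeError('shape/dtype mismatch')
  if x.varying != y.varying:
    raise TypeError('device variance mismatch; pbroadcast an operand first')
  return VarType(x.shape, x.dtype, x.varying)

# mul of f32[3]{i} with f32[3]{} is rejected until the second operand is
# pbroadcast to {i}:
a = VarType((3,), 'f32', frozenset({'i'}))
b = VarType((3,), 'f32', frozenset())
# binop_type_rule(a, b)  # raises TypeError
```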
Here's a table summarizing the device variance typing for collective primitives:
| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| `psum2` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pbroadcast` |
| `pbroadcast` | `Invariant -> Varying` | `y:f32[3]{i} = pbroadcast(x:f32[3], 'i')` | no-op (no communication) | `psum` |
| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` | `AllToAll` (communication) | `all_to_all` |
| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |
| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |
| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |
| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |
| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |
There are some surprising things here!
* We introduced several new primitives, including
- `pbroadcast`, which interestingly lowers to a no-op
- `all_gather_invariant`, which lowers to the same thing as `all_gather` but
has a different device variance type (essentially `all_gather` has a
`pbroadcast` fused into it, whereas `all_gather_invariant` does not)
- `pscatter` which is the dual (transpose) of `all_gather_invariant`
* `all_gather` has a device-varying result
Intuitively, the reason to introduce `pbroadcast` (other than to make the typing
rules work) is so that `psum` can transpose to a physical no-op. The reason we
need `all_gather` to have a device-varying result is so that we can transpose it
to `psum_scatter`; if we instead left it with a device-invariant result, we
might need a downstream `pbroadcast`, and that composition would transpose to an
inefficient `psum` followed by slicing / `pscatter`. So instead we have a
`pbroadcast` "fused into" the `all_gather`, thus allowing for an efficient
transpose into `psum_scatter`. We provide `all_gather_invariant` and its
transpose `pscatter` mainly for completeness; it's unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using `out_specs`).
Interestingly, the `psum` and `pbroadcast` transpose pair correspond to the
`psum_idrev` and `id_psumrev` that users introduced while training LLMs with
`pmap`.
#### How this system solves the inefficient transpose examples
Consider again the simplified motivating example:
```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y
```
With these new rules, the transpose is:
```python
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar
```
where evaluating the `pbroadcast` application involves no communication or FLOPs
at all; it's a no-op. Notice that if we keep transposing, the body does not grow
in size; indeed `t(t(f1)) == f1`. Efficiency achieved!
And we wouldn't mess up the other examples either, so long as we `pbroadcast` to
make the types check where needed:
```python
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
```
Intuitively, in Example 1 we now only have "half the original psum", whereas in
Example 2 we get both "halves". For Example 3 we never need any operations in
the body at all.
For the `all_gather` examples, Example 4 would need to use
`all_gather_invariant` to have an efficient transpose (though it'd be better to
use `out_specs` instead of the collective in the body):
```python
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar
```
For Example 5, using the device-varying `all_gather` works as we'd want:
```python
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
```
### How to make the API convenient for users (and backward compatible)
But what user wants to write `pbroadcast`s? And what developer wants to break
lots of existing user code involving `psum`s which are not fed into unmapped
outputs? Not me!
Instead we can automatically insert the `pbroadcast`s. It's a bit analogous to how
we do automatic rank promotion at the `jax.numpy` layer, inserting broadcasts to
avoid rank mismatch errors in binary operators. But it's much simpler since we
don't need to contend with shape tuples. The typical rule is: whenever we see a
multi-arity operation where the operands disagree in their device variance
types, take the union of operands' device variance types' axis name sets and
insert `pbroadcast`s to lift each operand to the resulting device variance type.
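As a sketch of that rule (illustrative plain Python; `pbroadcast` is passed in as a stand-in rather than imported from anywhere):

```python
def lift_operands(operands, variance_sets, pbroadcast):
  # Union of the operands' device variance axis-name sets.
  target = frozenset().union(*variance_sets)
  lifted = []
  for x, axes in zip(operands, variance_sets):
    missing = target - axes
    # Insert a pbroadcast only for operands whose variance set is smaller.
    lifted.append(pbroadcast(x, tuple(sorted(missing))) if missing else x)
  return lifted, target
```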
Automatically inserting `pbroadcast`s just before they're needed may mean we apply
the same `pbroadcast` to the same operand multiple times, creating common
subexpressions. When we transpose, those could turn into a sum-of-`psum`s rather
than a `psum`-of-sum. We'll rely on the compiler to clean that up as appropriate.
If it's a problem then we could add some simple memoization to the
`pbroadcast`-insertion pass.
The user API for `all_gather` will mean `all_gather_p` by default (not
`all_gather_invariant_p`), covering the common case and meaning no `pbroadcast`s
must be inserted.
We can provide an option on `shmap` to disable this automatic insertion of
`pbroadcast`s, in which case it'll be up to the user to ensure type-correctness.
This explicit option may be appealing to some who want to be explicit about
where the `psum`s occur in the backward pass.
### How to implement the solution
The key to making the implementation lightweight is that **we aren't going to
add these types to avals or jaxprs**. At least, not at first. That can be
expensive because it requires updating the rest of JAX, e.g. all consumers of
avals and jaxprs may need to handle the new types. We're not falling for that
again!
Instead we're going to keep these extended types as metadata internal to
`shmap`, just like the current "replication checking for `out_specs`" machinery
is internal to `shmap`. Indeed this solution amounts to a relatively small
extension to that existing machinery: it was already tracking the same
information; now we're just adding the `pbroadcast`s.
We have at least two options for where to perform the `pbroadcast` insertion:
1. just before transposition, in the transpose rule, where we have a jaxpr of
the computation to be transposed;
2. in every `shmap` body, whether eagerly executed or staged out, like the
current "replication checking for `out_specs`" machinery.
The former may end up being easier since we only have to handle the jaxpr case,
and only linear primitives. But we'll start by trying the latter so the
implementation here is a strict revision/extension to the existing
replication-checking logic.
## Appendix: defining and motivating maps with unmapped inputs and outputs
For concreteness, we'll mostly focus on `shmap`, though these same ideas apply
to e.g. `pmap` and probably `xmap`.
An argument/input is _unmapped_ along a mesh axis when the corresponding entry
of `in_specs` doesn't mention that mesh axis's name. Logically it means that
each function instance along that mesh axis gets the same value for the
argument. To the caller, each operand is sliced according to the mesh axes over
which the operand is mapped, whereas there is no slicing for mesh axes over
which the operand is unmapped.
An output is _unmapped_ along a mesh axis when the corresponding entry of
`out_specs` doesn't mention that mesh axis's name. Logically it means each
function instance along that mesh axis must return the same value. To the
caller, each result of the `shmap` is formed by concatenating the return values
of the function instances along the mesh axes over which the output is mapped,
whereas for mesh axes over which the output is unmapped only one copy of the
value is used.
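For instance, here is a tiny sketch with one unmapped input and an unmapped output (illustrative only, not from the JEP; it assumes the `jax.experimental.shard_map` import path and a one-axis mesh named `'i'`):

```python
import jax, jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('i',))

@partial(shard_map, mesh=mesh, in_specs=(P(), P('i')), out_specs=P())
def f(w, x):
  # w is unmapped: every function instance along 'i' sees the full w.
  # The output is unmapped: the psum makes all instances agree on it.
  return jax.lax.psum(jnp.vdot(w, x), 'i')
```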
See [the `shmap`
JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples
of unmapped inputs and outputs. For comparison, in `vmap` unmapped
inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather
than an `int`).
Here are reasons we like unmapped inputs and outputs for `shmap`:
* **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape
hatch should be able to do too. Or else we'd have a lacking escape hatch! If
we didn't have unmapped outputs in `shmap` then we couldn't express the same
batch-parallel loss function computations as `pjit`.
* **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped
inputs, and...
* **Closure under transposition.** Once we have unmapped inputs, it's natural to
be able to transpose to unmapped outputs.
So unmapped outputs are both canonical and useful!


@@ -48,6 +48,7 @@ Then create a pull request that adds a file named
12049: Type Annotation Roadmap for JAX <12049-type-annotations>
14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
15856: `jax.extend`, an extensions module <15856-jex>
17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>
Several early JEPs were converted in hindsight from other documentation,


@@ -1084,8 +1084,7 @@ def _why_alive_container_info(container, obj_id) -> str:
@contextmanager
def new_main(trace_type: type[Trace],
dynamic: bool = False,
def new_main(trace_type: type[Trace], dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
@@ -1111,6 +1110,20 @@ def new_main(trace_type: type[Trace],
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def new_dynamic(level: int) -> Generator[None, None, None]:
stack = thread_local_state.trace_state.trace_stack
prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level]
_update_thread_local_jit_state(stack.dynamic)
try:
yield
finally:
stack.dynamic = prev_dynamic
_update_thread_local_jit_state(stack.dynamic)
def dynamic_level() -> int:
return thread_local_state.trace_state.trace_stack.dynamic.level
@contextmanager
def new_base_main(trace_type: type[Trace],
**payload) -> Generator[MainTrace, None, None]:


@@ -853,6 +853,7 @@ def _custom_vjp_call_jaxpr_jvp(
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp


@@ -403,6 +403,7 @@ class JVPTrace(Trace):
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)


@@ -4966,6 +4966,17 @@ def _empty_lower(ctx, *, dtype):
mlir.register_lowering(empty_p, _empty_lower)
tie_p = core.Primitive('tie')
tie_p.def_impl(lambda x, y: y)
tie_p.def_abstract_eval(lambda x, y: y)
mlir.register_lowering(tie_p, lambda ctx, x, y: [y])
ad.primitive_jvps[tie_p] = \
lambda primals, tangents: (tie_p.bind(*primals), tangents[-1])
ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct]
pe.def_trivial_padding(tie_p)
batching.defvectorized(tie_p)
class BIntRules:
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:


@@ -812,6 +812,7 @@ batching.axis_primitive_batchers[psum_p] = \
partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes')
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind


@@ -1465,6 +1465,8 @@ tf_not_yet_impl = [
"pmin",
"ppermute",
"psum",
"psum2",
"pbroadcast",
"pmax",
"pgather",
"reduce_scatter",
@@ -3449,6 +3451,9 @@ def _reduce_precision(x, *, exponent_bits, mantissa_bits):
tf_impl[lax.reduce_precision_p] = _reduce_precision
tf_impl[lax_internal.tie_p] = lambda x, y: y
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
m = tf.Module()

File diff suppressed because it is too large


@@ -37,6 +37,7 @@ from jax._src import xla_bridge
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp
@@ -113,9 +114,12 @@ class ShardMapTest(jtu.JaxTestCase):
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
assert a.device_buffers[0].shape == (4, 2)
# NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use
# all_gather_invariant primitive, which differs in its output replication
# type compared to all_gather.
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y')))
in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
def fwd(a):
return lax.all_gather(a, 'z', axis=0, tiled=True)
@@ -559,8 +563,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_partial_eval_custom_axis_env(self):
mesh = Mesh(jax.devices(), ('i',))
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'),
check_rep=False) # check_rep=False b/c no scan rep rule yet
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(_):
_, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')),
None, None, length=1)
@@ -861,6 +864,274 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertIn("\"[('i',)]\"", mhlo_str)
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)
def test_rewrite_process_call(self):
def f(x):
return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
def test_rewrite_post_process_call(self):
# We shouldn't hit post_process_call here because of RewriteTrace's dynamic
# behavior (i.e. no data dependence).
mesh = jtu.create_global_mesh((4,), ('x',))
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
def f(x):
return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x
x = jnp.arange(4.)
y = f(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
def test_rewrite_process_custom_jvp_call(self):
@jax.custom_jvp
def foo(x):
return 2. * x
@foo.defjvp
def foo_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return foo(x), 2. * x_dot
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,))
self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)
def test_rewrite_process_custom_vjp_call(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return foo(x), None
def foo_bwd(_, y_bar):
return 2. * y_bar,
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
def test_rewrite_process_custom_vjp_call_match_more_replicated(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return foo(x), None
def foo_bwd(_, y_bar):
return jnp.ones_like(y_bar), # diff! more replicated than primal/tangent
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)
def test_rewrite_process_custom_vjp_call_match_less_replicated(self):
@jax.custom_vjp
def foo(x, y):
del y
return 2. * x
def foo_fwd(x, y):
return foo(x, y), y
def foo_bwd(y, _):
return y, None # diff! x_bar less replicated than primal/tangent
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x, y: foo(x, y) * y, mesh,
in_specs=(P(), P('x')), out_specs=P('x'))
x = jnp.arange(4.)
y = jnp.arange(4 * 4.)
z = jax.jit(g)(x, y)
self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)
z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y)
self.assertAllClose(z.sum(), z_, check_dtypes=False)
self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
check_dtypes=False)
def test_rewrite_custom_vjp_call_jaxpr(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return foo(x), None
def foo_bwd(_, y_bar):
return 2. * y_bar,
foo.defvjp(foo_fwd, foo_bwd)
def foo_scan(x):
y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1)
return y
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo_scan(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
def test_transpose_identity(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def f(x):
return x
jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.)
e, = jaxpr.jaxpr.eqns
self.assertEmpty(e.params['jaxpr'].eqns)
jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,))
e, = jaxpr.jaxpr.eqns
self.assertEmpty(e.params['jaxpr'].eqns)
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def g(x):
return jax.jit(lambda x: x)(x)
jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
e, = jaxpr.jaxpr.eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEmpty(e1.outvars)
self.assertEmpty(e2.params['jaxpr'].eqns)
def test_fanout_specs_transpose_to_psum(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x'))
def f(x):
return x
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e2, = e.params['jaxpr'].eqns
self.assertEqual(str(e2.primitive), 'psum2')
self.assertEqual(e2.params['axes'], ('x',))
def test_fanin_psum_transposes_to_fanout(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P())
def f(x):
return jax.lax.psum(x, 'x')
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.]))
e, = jaxpr.jaxpr.eqns
e1, = e.params['jaxpr'].eqns
self.assertEqual(str(e1.primitive), 'pbroadcast')
def test_psum_with_implicit_fanout_self_transposes(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def f(x):
return jax.lax.psum(x, 'x')
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEqual(str(e1.primitive), 'psum2')
self.assertEqual(str(e2.primitive), 'pbroadcast')
def test_rewrite_binops(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x'))
def f(x, y):
return x * y
jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e = e.params['jaxpr'].eqns[0]
self.assertEqual(e.primitive.name, 'pbroadcast')
self.assertEqual(e.params['axes'], ('x',))
def test_rewrite_scan(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def f(x):
x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None,
length=2)
return x
jaxpr = jax.make_jaxpr(f)(jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e, = e.params['jaxpr'].eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEqual(e1.primitive.name, 'psum2')
self.assertEqual(e2.primitive.name, 'pbroadcast')
def test_check_rep_false_grads(self):
# This test is redundant with the systematic tests below, but it serves as a
# direct regression test for a bug.
mesh = jtu.create_global_mesh((4,), ('heads',))
def f(q, k, v):
def body(q, k, v):
return q * k[None, :] + v[None, :]
out = shard_map(body, mesh, check_rep=False,
in_specs=(q_spec, kv_spec, kv_spec,),
out_specs=q_spec)(q, k, v)
return out.sum()
q_spec = P('heads', None)
kv_spec = P(None)
q = jax.device_put(jnp.arange(32.).reshape(4, 8), jax.sharding.NamedSharding(mesh, q_spec))
k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
class FunSpec(NamedTuple):
name: str
@@ -1126,12 +1397,15 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
(name + f'_check_rep={check_rep}', *params, check_rep)
for (name, *params) in sample(config.FLAGS.jax_num_generated_cases, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _):
def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
f = shard_map(fun, mesh, in_specs, out_specs)
f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep)
if jit:
f = jax.jit(f)
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)