mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
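As a rough, runnable illustration of the path this change restores (not the fix itself; the actual transpose-rule change is in the `shard_map` diff below, which the viewer suppresses): with `check_rep=False`, differentiating through a `shard_map` that has a replicated (fan-out) input and an unmapped output relies on the transpose rule inserting exactly those psum/div operations.

```python
# Illustration only: exercises the check_rep=False transpose path described
# in the commit message; works on one or many devices.
import jax, jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('i',))

@partial(shard_map, mesh=mesh, in_specs=(P(), P('i')), out_specs=P(),
         check_rep=False)
def f(w, x):              # w is replicated (fan-out), x is sharded along 'i'
  return jax.lax.pmean(jnp.sum(w * x), 'i')   # unmapped (fan-in-consensus) output

w = jnp.ones(4)
x = jnp.arange(4. * jax.device_count())
# The gradient w.r.t. the replicated input w sums contributions across
# devices; producing it requires the psum/div that the transpose rule
# reinserts when the body rewrite has not been applied.
print(jax.grad(f)(w, x))
```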
This commit is contained in:
parent
c38f67043c
commit
70b58bbd30
515
docs/jep/17111-shmap-transpose.md
Normal file
@ -0,0 +1,515 @@
# Efficient transposition of replication-inducing collectives

*mattjj@*, *dougalm@*

*August 2023*

## Motivation

We have an efficiency problem in automatically transposing `shmap`s containing
certain collectives. The issue arises with `psum` and `all_gather`, specifically
when the output of the collective is returned to the caller as an unmapped
output. And it's not an edge case: for example, it arises when applying `grad`
to a `shmap`-based batch data parallel neural network loss function which uses
`psum` to compute the total loss.

We've known about this problem for some time. An analogous issue exists with
`pmap`, though it's been worked around by keeping `grad` inside `pmap` rather than
outside. A primary goal of the incomplete avals-with-names work was to address a
version of this transpose efficiency problem. This doc draws on those ideas,
while extending and revising them to handle more cases and to be much easier to
land. Indeed the solution proposed here only affects the `shmap` implementation.
The rest of the system need not be changed (yet).

The main purpose of this doc is to define this transpose efficiency problem and
propose an easy-to-land solution.

This doc is not about:
* logical axis names on arrays (the only axis names here are just like in
  `shmap` and OG `pmap`);
* changing autodiff semantics (all the numbers and (non)errors are staying the
  same, we're just making things more efficient);
* allowing user code to reflect on any new information, or really affecting user
  code at all.

## Problem: efficient transpose of `psum` or `all_gather` depends on whether cotangents are invariant across devices
Consider this semi-realistic example, meant to resemble a replicated-parameter
batch data parallel loss function:

```python
devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss
```

Notice the `out_specs=P()`, which indicates an unmapped output. If you're not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.

Most of the details in the `loss` example aren't important. All that matters for
our purposes is that we're applying `psum` (or rather `pmean = lambda x, name:
psum(x, name) / psum(1, name)`) at the end. So a distilled version looks like
this:

```python
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())
```

We even simplified notation by suppressing the `mesh` argument. In the examples to
follow it can be inferred from context.

What does the transpose look like? Writing `t` to mean function transpose, we
could evaluate `t(f1)(ybar)` for any `ybar` efficiently by applying the function
`¿f1_transpose?` below:

```python
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
```

But that's not the transpose we currently get as `t(f1)`.

Instead, the current recipe for transposition is roughly that we switch
`in_specs` and `out_specs`, do some division rescaling for unmapped outputs, and
transpose the body. Because `psum` is its own transpose (as an all-reduce sum),
we end up producing this transpose:

```python
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))
```

This transpose gets the numbers right, but it's wasteful. We know statically
from the transpose's `in_specs=P()` that `ybar` has the same value for each function
instance, i.e. that its value is device-invariant for devices along the mesh
axis named `i`, and yet we apply a `psum` to it! That uses expensive communication
just to multiply the value on each device by 8. (Here 8 refers to the size of
axis `i`. The division by 8 comes from the original function's `out_specs=P()`; it
and the trivial `psum` basically cancel each other out.)
What are we doing wrong? We're not exploiting the fact that cotangents `ybar`
corresponding to `f1`'s unmapped outputs are guaranteed to be device-invariant;
instead, we're defensively `psum`ming them as if they weren't because `psum`'s
transpose can't be sure given the local information it has. Sometimes the `psum`
is necessary, as in transposing `f2` with respect to its first argument:

```python
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
```

Intuitively, if our transpose machinery could tell the difference between
Example 1 and Example 2, we could do better by avoiding the `psum` and division
where possible.

The inefficient examples can be even smaller. Consider transposing this cursed
identity function:

```python
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...
```

It keeps getting bigger the more we transpose. How embarrassing!

And `psum` isn't the only culprit. Something analogous holds true for
`all_gather`:

```python
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
```

This program is a bit artificial. Why do an `all_gather` and feed the result into
an unmapped output, rather than skipping the `all_gather` in the body and just
using `out_specs=P('i')` to collect the results? But even though it's cooked-up,
this example nevertheless exhibits a transpose which unnecessarily performs
communication (we could have just performed a non-communicating slice),
analogous to Example 1 for `psum`.

Also analogously to the `psum` examples, the defensive `psum_scatter` is
necessary in some cases:

```python
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
```

So how do we avoid these inefficient transposes?
## Solutions

Here are two solution ideas. They aren't mutually exclusive. But (spoilers) the
second one is better, and it's all we need.

### Partial solution "P-sum": build the ability to express a `psum` into `out_specs`

This solution is a bit of a strawperson because it would offer only an awkward
way to write programs. And it wouldn't even fix everything! But it's worth
considering, if only to motivate a more complete solution.

Example 4 above is artificial because we could have just used `out_specs` instead
of an `all_gather` in the body:

```python
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
```

The `f4_better` version doesn't have any transposition problems, since the
transpose problems arise from collectives in the body.

Analogously, we could fix Example 1 by extending `out_specs` so that they can
express summing:

```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
```

So offering `psum`s built into `out_specs` fixes the transpose problem of
Example 1. But it doesn't fully fix the cursed identity transpose in Example 3:

```python
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
```

It's an improvement since the program doesn't continue to get bigger as we keep
transposing, but we're still doing wasteful communication.
### Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives

This solution has two components:
1. track when values are guaranteed to be device-invariant vs device-varying
   over particular mesh axes, and
2. decompose `psum` into a two-step process, introducing a new `pbroadcast`
   primitive, and introduce new primitives for `all_gather` and its transposes.

Morally, the tracking of device-invariant vs device-varying information is a
type-level consideration. But for the expedience of our first implementation, we
don't need to literally add the information to abstract values or jaxpr types.
Before we get to implementation, we'll first introduce the idea using types.

Also to follow is a discussion of making the user API convenient and backward
compatible. But to first introduce the idea, we'll ignore convenience and
instead write code that is as explicit as possible.

#### Tracking device invariance in avals (a.k.a. avals-with-names, revived)

We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a `shmap` are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We'll call such values device-invariant. For values that are not
device-invariant, we'll say they're device-varying, though really we mean
potentially device-varying from the point of view of the type system.

To encode device variance in types, we'll extend the syntax of types for arrays.
We'll write things like `x:f32[3,4]{i}` to indicate that `x` is (potentially)
device-varying along mesh axis `i` (and device-invariant over any other mesh
axes of the `shmap`). More generally, we'll say the grammar for array type
syntax is something like

```
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
```
We'll also update the typing rules to handle device variance types:
* for first-order primitives other than collectives
  - for multi-arity primitives, the operand device variance types must be equal
    where shapes must be equal, e.g. `mul x:f32[s1]{r1} y:f32[s2]{r2}` requires
    `r1 == r2` in addition to `s1 == s2`
  - the output device variance type must be the same as the operand(s)
* for higher-order primitives
  - we just instantiate any type variables including the device variance type
    (and checking types for equality checks their device variance types are
    equal)
  - (when performing type inference, e.g. for branches of a `cond`, we take the
    union of the sets of axis names in device variance types)
* for first-order collectives
  - a collective can either accept a device-varying or device-invariant input
    (along a mesh axis corresponding to its axis name parameter); it's an error
    to pass a device-invariant operand to a collective which accepts
    device-varying operands and vice-versa
  - a collective can either produce a device-varying or device-invariant output
  - see the table below

As a side benefit, whatever logic implements this type checking can subsume
`shmap`'s "static analysis" check for whether a `shmap` body function is
compatible with any unmapped `out_specs`.
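
To make those rules concrete, here is a minimal sketch (not the actual implementation, which as discussed later keeps this information internal to `shmap`) that pictures a device variance type as a set of mesh axis names:

```python
# Minimal sketch only: device variance represented as a set of mesh axis
# names, with the checking/inference rules described above.
from dataclasses import dataclass

@dataclass(frozen=True)
class VarianceType:
  axes: frozenset          # mesh axes the value may (potentially) vary over

def check_first_order(*operands: VarianceType) -> VarianceType:
  # Multi-arity, non-collective primitives: operand device variance types
  # must be equal, and the output type matches the operands.
  if any(t != operands[0] for t in operands[1:]):
    raise TypeError(f"device variance mismatch: {operands}")
  return operands[0]

def infer_join(*operands: VarianceType) -> VarianceType:
  # Type inference for higher-order primitives (e.g. cond branches): take
  # the union of the sets of axis names.
  return VarianceType(frozenset().union(*(t.axes for t in operands)))
```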

Here's a table summarizing the device variance typing for collective primitives:

| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| `psum2` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pbroadcast` |
| `pbroadcast` | `Invariant -> Varying` | `y:f32[3]{i} = pbroadcast(x:f32[3], 'i')` | no-op (no communication) | `psum` |
| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` | `AllToAll` (communication) | `all_to_all` |
| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |
| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |
| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |
| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |
| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |

There are some surprising things here!
* We introduced several new primitives, including
  - `pbroadcast`, which interestingly lowers to a no-op
  - `all_gather_invariant`, which lowers to the same thing as `all_gather` but
    has a different device variance type (essentially `all_gather` has a
    `pbroadcast` fused into it, whereas `all_gather_invariant` does not)
  - `pscatter`, which is the dual (transpose) of `all_gather_invariant`
* `all_gather` has a device-varying result

Intuitively, the reason to introduce `pbroadcast` (other than to make the typing
rules work) is so that `psum` can transpose to a physical no-op. The reason we
need `all_gather` to have a device-varying result is so that we can transpose it
to `psum_scatter`; if we instead left it with a device-invariant result, we
might need a downstream `pbroadcast`, and that composition would transpose to an
inefficient `psum` followed by slicing / `pscatter`. So instead we have a
`pbroadcast` "fused into" the `all_gather`, thus allowing for an efficient
transpose into `psum_scatter`. We provide `all_gather_invariant` and its
transpose `pscatter` mainly for completeness; it's unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using `out_specs`).

Interestingly, the `psum` and `pbroadcast` transpose pair corresponds to the
`psum_idrev` and `id_psumrev` that users introduced while training LLMs with
`pmap`.
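
As a small, runnable check of that pairing (in the spirit of the tests added in this change rather than part of the proposal itself), transposing a `shard_map` whose body `psum`s into an unmapped output yields a body whose only collective is a `pbroadcast`:

```python
# Runnable illustration: the transpose of a fan-in psum is a (no-op)
# pbroadcast, mirroring test_fanin_psum_transposes_to_fanout below.
import jax, jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('i',))

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P())
def f(x):
  return jax.lax.psum(x, 'i')

n = jax.device_count()
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(float(n)))[1])(jnp.array([1.]))
print(jaxpr)  # the transposed shard_map body contains a pbroadcast
```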

#### How this system solves the inefficient transpose examples

Consider again the simplified motivating example:

```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y
```

With these new rules, the transpose is:

```python
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar
```

where evaluating the `pbroadcast` application involves no communication or FLOPs
at all; it's a no-op. Notice that if we keep transposing, the body does not grow
in size; indeed `t(t(f1)) == f1`. Efficiency achieved!

And we wouldn't mess up the other examples either, so long as we `pbroadcast` to
make the types check where needed:

```python
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device-varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
```

Intuitively, in Example 1 we now only have "half the original psum", whereas in
Example 2 we get both "halves". For Example 3 we never need any operations in
the body at all.

For the `all_gather` examples, Example 4 would need to use
`all_gather_invariant` to have an efficient transpose (though it'd be better to
use `out_specs` instead of the collective in the body):

```python
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar
```

For Example 5, using the device-varying `all_gather` works as we'd want:

```python
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
```

### How to make the API convenient for users (and backward compatible)

But what user wants to write `pbroadcast`s? And what developer wants to break
lots of existing user code involving `psum`s which are not fed into unmapped
outputs? Not me!

Instead we can automatically insert the `pbroadcast`s. It's a bit analogous to how
we do automatic rank promotion at the `jax.numpy` layer, inserting broadcasts to
avoid rank mismatch errors in binary operators. But it's much simpler since we
don't need to contend with shape tuples. The typical rule is: whenever we see a
multi-arity operation where the operands disagree in their device variance
types, take the union of operands' device variance types' axis name sets and
insert `pbroadcast`s to lift each operand to the resulting device variance type.
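
A minimal sketch of that insertion rule, reusing the set-of-axis-names view from the earlier sketch (illustrative only, not JAX's internal representation):

```python
# Sketch of the automatic insertion rule: compute, for each operand of a
# multi-arity op, which mesh axes it must be pbroadcast over so that all
# operands end up with the same device variance type.
def pbroadcasts_to_insert(variance_types):
  target = frozenset().union(*variance_types)   # union of axis-name sets
  # an empty set means that operand needs no pbroadcast
  return target, [target - axes for axes in variance_types]

# e.g. for `x * y` where x is device-invariant and y varies over 'i':
target, inserts = pbroadcasts_to_insert([frozenset(), frozenset({'i'})])
# target == frozenset({'i'}); inserts == [frozenset({'i'}), frozenset()],
# i.e. insert pbroadcast(x, 'i') and leave y alone.
```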

Automatically inserting `pbroadcast`s just before they're needed may mean we apply
the same `pbroadcast` to the same operand multiple times, creating common
subexpressions. When we transpose, those could turn into a sum-of-`psum`s rather
than a `psum`-of-sum. We'll rely on the compiler to clean that up as appropriate.
If it's a problem then we could add some simple memoization to the
`pbroadcast`-insertion pass.

The user API for `all_gather` will mean `all_gather_p` by default (not
`all_gather_invariant_p`), covering the common case and meaning no `pbroadcast`s
must be inserted.

We can provide an option on `shmap` to disable this automatic insertion of
`pbroadcast`s, in which case it'll be up to the user to ensure type-correctness.
This explicit option may be appealing to some who want to be explicit about
where the `psum`s occur in the backward pass.

### How to implement the solution

The key to making the implementation lightweight is that **we aren't going to
add these types to avals or jaxprs**. At least, not at first. That can be
expensive because it requires updating the rest of JAX, e.g. all consumers of
avals and jaxprs may need to handle the new types. We're not falling for that
again!

Instead we're going to keep these extended types as metadata internal to
`shmap`, just like the current "replication checking for `out_specs`" machinery
is internal to `shmap`. Indeed this solution amounts to a relatively small
extension to that existing machinery: it was already tracking the same
information; now we're just adding the `pbroadcast`s.

We have at least two options for where to perform the `pbroadcast` insertion:
1. just before transposition, in the transpose rule, where we have a jaxpr of
   the computation to be transposed;
2. in every `shmap` body, whether eagerly executed or staged out, like the
   current "replication checking for `out_specs`" machinery.

The former may end up being easier since we only have to handle the jaxpr case,
and only linear primitives. But we'll start by trying the latter so the
implementation here is a strict revision/extension to the existing
replication-checking logic.

## Appendix: defining and motivating maps with unmapped inputs and outputs

For concreteness, we'll mostly focus on `shmap`, though these same ideas apply
to e.g. `pmap` and probably `xmap`.

An argument/input is _unmapped_ along a mesh axis when the corresponding entry
of `in_specs` doesn't mention that mesh axis's name. Logically it means that
each function instance along that mesh axis gets the same value for the
argument. To the caller, each operand is sliced according to the mesh axes over
which the operand is mapped, whereas there is no slicing for mesh axes over
which the operand is unmapped.

An output is _unmapped_ along a mesh axis when the corresponding entry of
`out_specs` doesn't mention that mesh axis's name. Logically it means each
function instance along that mesh axis must return the same value. To the
caller, each result of the `shmap` is formed by concatenating the return values
of every function instance along which the outputs are mapped, whereas for mesh
axes over which the output is unmapped only one copy of the value is used.

See [the `shmap`
JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples
of unmapped inputs and outputs. For comparison, in `vmap` unmapped
inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather
than an `int`).
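
As a small concrete illustration (not one of the numbered examples above), here is a `shard_map` with an unmapped input and an unmapped output, alongside the `vmap` analogue of an unmapped input:

```python
# Illustrative example only: `w` is unmapped (every function instance along
# 'i' sees the full array), `x` is mapped along 'i', and out_specs=P() is an
# unmapped output (every instance must return the same value, ensured here
# by the pmean).
import jax, jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), ('i',))
n = jax.device_count()

@partial(shard_map, mesh=mesh, in_specs=(P(), P('i')), out_specs=P())
def f(w, x):
  return jax.lax.pmean(jnp.vdot(w, x.mean(axis=0)), 'i')

print(f(jnp.ones(3), jnp.arange(3. * n).reshape(n, 3)))

# The vmap analogue of an unmapped input is in_axes=None:
print(jax.vmap(jnp.vdot, in_axes=(None, 0))(jnp.ones(3),
                                            jnp.arange(3. * n).reshape(n, 3)))
```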

Here are reasons we like unmapped inputs and outputs for `shmap`:
* **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape
  hatch should be able to do too. Or else we'd have a lacking escape hatch! If
  we didn't have unmapped outputs in `shmap` then we couldn't express the same
  batch-parallel loss function computations as `pjit`.
* **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped
  inputs, and...
* **Closure under transposition.** Once we have unmapped inputs, it's natural to
  be able to transpose to unmapped outputs.

So unmapped outputs are both canonical and useful!
@ -48,6 +48,7 @@ Then create a pull request that adds a file named
   12049: Type Annotation Roadmap for JAX <12049-type-annotations>
   14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
   15856: `jax.extend`, an extensions module <15856-jex>
   17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>

Several early JEPs were converted in hindsight from other documentation,
@ -1084,8 +1084,7 @@ def _why_alive_container_info(container, obj_id) -> str:


@contextmanager
def new_main(trace_type: type[Trace],
             dynamic: bool = False,
def new_main(trace_type: type[Trace], dynamic: bool = False,
             **payload) -> Generator[MainTrace, None, None]:
  # See comments in https://github.com/google/jax/pull/3370
  stack = thread_local_state.trace_state.trace_stack
@ -1111,6 +1110,20 @@ def new_main(trace_type: type[Trace],
        leaked_tracers = maybe_find_leaked_tracers(t())
        if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)

@contextmanager
def new_dynamic(level: int) -> Generator[None, None, None]:
  stack = thread_local_state.trace_state.trace_stack
  prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level]
  _update_thread_local_jit_state(stack.dynamic)
  try:
    yield
  finally:
    stack.dynamic = prev_dynamic
    _update_thread_local_jit_state(stack.dynamic)

def dynamic_level() -> int:
  return thread_local_state.trace_state.trace_stack.dynamic.level

@contextmanager
def new_base_main(trace_type: type[Trace],
                  **payload) -> Generator[MainTrace, None, None]:
@ -853,6 +853,7 @@ def _custom_vjp_call_jaxpr_jvp(
  tangents_out = ad.custom_lin_p.bind(
      *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
      out_avals=avals_out, symbolic_zeros=symbolic_zeros)
  tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
  tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
  return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
@ -403,6 +403,7 @@ class JVPTrace(Trace):
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        out_avals=avals_out, symbolic_zeros=symbolic_zeros)
    tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out)
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)
@ -4966,6 +4966,17 @@ def _empty_lower(ctx, *, dtype):
mlir.register_lowering(empty_p, _empty_lower)


tie_p = core.Primitive('tie')
tie_p.def_impl(lambda x, y: y)
tie_p.def_abstract_eval(lambda x, y: y)
mlir.register_lowering(tie_p, lambda ctx, x, y: [y])
ad.primitive_jvps[tie_p] = \
    lambda primals, tangents: (tie_p.bind(*primals), tangents[-1])
ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct]
pe.def_trivial_padding(tie_p)
batching.defvectorized(tie_p)


class BIntRules:
  @staticmethod
  def physical_element_aval(dtype) -> core.ShapedArray:
@ -812,6 +812,7 @@ batching.axis_primitive_batchers[psum_p] = \
    partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes')


# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind
@ -1465,6 +1465,8 @@ tf_not_yet_impl = [
    "pmin",
    "ppermute",
    "psum",
    "psum2",
    "pbroadcast",
    "pmax",
    "pgather",
    "reduce_scatter",
@ -3449,6 +3451,9 @@ def _reduce_precision(x, *, exponent_bits, mantissa_bits):

tf_impl[lax.reduce_precision_p] = _reduce_precision

tf_impl[lax_internal.tie_p] = lambda x, y: y


def _register_checkpoint_pytrees():
  """Registers TF custom container types as pytrees."""
  m = tf.Module()
File diff suppressed because it is too large
@ -37,6 +37,7 @@ from jax._src import xla_bridge
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp

@ -113,9 +114,12 @@ class ShardMapTest(jtu.JaxTestCase):
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.device_buffers[0].shape == (4, 2)

    # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use
    # all_gather_invariant primitive, which differs in its output replication
    # type compared to all_gather.
    @jax.jit
    @partial(shard_map, mesh=mesh,
             in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y')))
             in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
    def fwd(a):
      return lax.all_gather(a, 'z', axis=0, tiled=True)

@ -559,8 +563,7 @@ class ShardMapTest(jtu.JaxTestCase):
  def test_partial_eval_custom_axis_env(self):
    mesh = Mesh(jax.devices(), ('i',))

    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'),
             check_rep=False)  # check_rep=False b/c no scan rep rule yet
    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(_):
      _, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')),
                            None, None, length=1)
@ -861,6 +864,274 @@ class ShardMapTest(jtu.JaxTestCase):
    self.assertIn("\"[('i',)]\"", mhlo_str)
    self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)

  def test_rewrite_process_call(self):
    def f(x):
      return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)
    y = jax.jit(g)(x)  # eager requires shmap to have ShardMapTrace.process_call
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

  def test_rewrite_post_process_call(self):
    # We shouldn't hit post_process_call here because of RewriteTrace's dynamic
    # behavior (i.e. no data dependence).
    mesh = jtu.create_global_mesh((4,), ('x',))

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
    def f(x):
      return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x

    x = jnp.arange(4.)
    y = f(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

  def test_rewrite_process_custom_jvp_call(self):
    @jax.custom_jvp
    def foo(x):
      return 2. * x

    @foo.defjvp
    def foo_jvp(primals, tangents):
      (x,), (x_dot,) = primals, tangents
      return foo(x), 2. * x_dot

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)

    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,))
    self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
    self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)

  def test_rewrite_process_custom_vjp_call(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 2. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))

    x = jnp.arange(4.)
    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

  def test_rewrite_process_custom_vjp_call_match_more_replicated(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return jnp.ones_like(y_bar),  # diff! more replicated than primal/tangent

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)

    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)

  def test_rewrite_process_custom_vjp_call_match_less_replicated(self):
    @jax.custom_vjp
    def foo(x, y):
      del y
      return 2. * x

    def foo_fwd(x, y):
      return foo(x, y), y

    def foo_bwd(y, _):
      return y, None  # diff! x_bar less replicated than primal/tangent

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x, y: foo(x, y) * y, mesh,
                  in_specs=(P(), P('x')), out_specs=P('x'))
    x = jnp.arange(4.)
    y = jnp.arange(4 * 4.)

    z = jax.jit(g)(x, y)
    self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)

    z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y)
    self.assertAllClose(z.sum(), z_, check_dtypes=False)
    self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
                        check_dtypes=False)

  def test_rewrite_custom_vjp_call_jaxpr(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 2. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    def foo_scan(x):
      y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1)
      return y

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo_scan(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))

    x = jnp.arange(4.)
    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

  def test_transpose_identity(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def f(x):
      return x

    jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.)
    e, = jaxpr.jaxpr.eqns
    self.assertEmpty(e.params['jaxpr'].eqns)

    jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,))
    e, = jaxpr.jaxpr.eqns
    self.assertEmpty(e.params['jaxpr'].eqns)

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def g(x):
      return jax.jit(lambda x: x)(x)

    jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEmpty(e1.outvars)
    self.assertEmpty(e2.params['jaxpr'].eqns)

  def test_fanout_specs_transpose_to_psum(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x'))
    def f(x):
      return x

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e2, = e.params['jaxpr'].eqns
    self.assertEqual(str(e2.primitive), 'psum2')
    self.assertEqual(e2.params['axes'], ('x',))

  def test_fanin_psum_transposes_to_fanout(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P())
    def f(x):
      return jax.lax.psum(x, 'x')

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.]))
    e, = jaxpr.jaxpr.eqns
    e1, = e.params['jaxpr'].eqns
    self.assertEqual(str(e1.primitive), 'pbroadcast')

  def test_psum_with_implicit_fanout_self_transposes(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jax.lax.psum(x, 'x')

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEqual(str(e1.primitive), 'psum2')
    self.assertEqual(str(e2.primitive), 'pbroadcast')

  def test_rewrite_binops(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x'))
    def f(x, y):
      return x * y

    jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e = e.params['jaxpr'].eqns[0]
    self.assertEqual(e.primitive.name, 'pbroadcast')
    self.assertEqual(e.params['axes'], ('x',))

  def test_rewrite_scan(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None,
                          length=2)
      return x

    jaxpr = jax.make_jaxpr(f)(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e, = e.params['jaxpr'].eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEqual(e1.primitive.name, 'psum2')
    self.assertEqual(e2.primitive.name, 'pbroadcast')

  def test_check_rep_false_grads(self):
    # This test is redundant with the systematic tests below, but it serves as a
    # direct regression test for a bug.
    mesh = jtu.create_global_mesh((4,), ('heads',))

    def f(q, k, v):

      def body(q, k, v):
        return q * k[None, :] + v[None, :]

      out = shard_map(body, mesh, check_rep=False,
                      in_specs=(q_spec, kv_spec, kv_spec,),
                      out_specs=q_spec)(q, k, v)
      return out.sum()

    q_spec = P('heads', None)
    kv_spec = P(None)
    q = jax.device_put(jnp.arange(32.).reshape(4, 8), jax.sharding.NamedSharding(mesh, q_spec))
    k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
    v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))

    jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)

class FunSpec(NamedTuple):
  name: str
@ -1126,12 +1397,15 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
    self.assertAllClose(expected, out, check_dtypes=False)

  @parameterized.named_parameters(
      sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
      (name + f'_check_rep={check_rep}', *params, check_rep)
      for (name, *params) in sample(config.FLAGS.jax_num_generated_cases, sample_shmap)
      for check_rep in [True, False]
  )
  @jax.default_matmul_precision("float32")
  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _):
  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
    f = shard_map(fun, mesh, in_specs, out_specs)
    f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep)
    if jit:
      f = jax.jit(f)
    jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)