mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
rolling forward shard_map transpose fixes
The new efficient-transpose path, enabled by setting check_rep=True in the
shard_map call, kept working. But the change inadvertently broke the
check_rep=False path. And because most tests set check_rep=True, we didn't
notice it in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule
to insert psums corresponding to in_specs with fan-out, and correspondingly
insert division for out_specs with fan-in-consensus. (With the new
check_rep=True path that this change adds, those extra operations aren't
necessary, as the body itself transposes correctly.) But the PR accidentally
removed those!

The fix was simple: just track whether we've applied the
efficient-transpose-body-rewrite (i.e. whether we're in the new
body-is-transposable path or the old need-extra-operations path) by adding a
boolean parameter `rewrite` to the shard_map primitive, and if the rewrite
hasn't been applied then include the explicit psum/div operations in the
transpose rule.

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
This commit is contained in:
parent c38f67043c
commit 70b58bbd30
515 docs/jep/17111-shmap-transpose.md Normal file
@ -0,0 +1,515 @@
# Efficient transposition of replication-inducing collectives

*mattjj@*, *dougalm@*

*August 2023*

## Motivation

We have an efficiency problem in automatically transposing `shmap`s containing
certain collectives. The issue arises with `psum` and `all_gather`, specifically
when the output of the collective is returned to the caller as an unmapped
output. And it's not an edge case: for example, it arises when applying `grad`
to a `shmap`-based batch data parallel neural network loss function which uses
`psum` to compute the total loss.

We've known about this problem for some time. An analogous issue exists with
`pmap`, though it's been worked around by keeping `grad` inside `pmap` rather than
outside. A primary goal of the incomplete avals-with-names work was to address a
version of this transpose efficiency problem. This doc draws on those ideas,
while extending and revising them to handle more cases and to be much easier to
land. Indeed the solution proposed here only affects the `shmap` implementation.
The rest of the system need not be changed (yet).

The main purpose of this doc is to define this transpose efficiency problem and
propose an easy-to-land solution.

This doc is not about:
* logical axis names on arrays (the only axis names here are just like in
  `shmap` and OG `pmap`);
* changing autodiff semantics (all the numbers and (non)errors are staying the
  same, we're just making things more efficient);
* allowing user code to reflect on any new information, or really affecting user
  code at all.
## Problem: efficient transpose of `psum` or `all_gather` depends on whether cotangents are invariant across devices

Consider this semi-realistic example, meant to resemble a replicated-parameter
batch data parallel loss function:

```python
devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss
```

Notice the `out_specs=P()`, which indicates an unmapped output. If you're not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.

Most of the details in the `loss` example aren't important. All that matters for
our purposes is that we're applying `psum` (or rather `pmean = lambda x, name:
psum(x, name) / psum(1, name)`) at the end. So a distilled version looks like
this:

```python
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())
```

We even simplified notation by suppressing the `mesh` argument. In the examples to
follow it can be inferred from context.

What does the transpose look like? Writing `t` to mean function transpose, we
could evaluate `t(f1)(ybar)` for any `ybar` efficiently by applying the function
`¿f1_transpose?` below:

```python
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
```

But that's not the transpose we currently get as `t(f1)`.

Instead, the current recipe for transposition is roughly that we switch
`in_specs` and `out_specs`, do some division rescaling for unmapped outputs, and
transpose the body. Because `psum` is its own transpose (as an all-reduce sum),
we end up producing this transpose:

```python
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))
```

This transpose gets the numbers right, but it's wasteful. We know statically
from the transpose's `in_specs=P()` that `ybar` has the same value for each function
instance, i.e. that its value is device-invariant for devices along the mesh
axis named `i`, and yet we apply a `psum` to it! That uses expensive communication
just to multiply the value on each device by 8. (Here 8 refers to the size of
axis `i`. The division by 8 comes from the original function's `out_specs=P()`; it
and the trivial `psum` basically cancel each other out.)
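
To see the waste concretely, here's a small runnable check (our addition, not
one of this doc's numbered examples; it assumes a CPU backend where `XLA_FLAGS`
can fake an 8-device mesh): the body of `t(f1)`, minus `t(g)`, round-trips
through an all-reduce just to reproduce its replicated input.

```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # set before importing jax

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('i',))
# divide-by-8 then psum: on a replicated input this is an expensive identity
roundtrip = shard_map(lambda ybar: jax.lax.psum(ybar / 8, 'i'), mesh,
                      in_specs=P(), out_specs=P())
print(roundtrip(jnp.float32(3.)))  # 3.0, reproduced via communication we didn't need
```
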
What are we doing wrong? We're not exploiting the fact that cotangents `ybar`
corresponding to `f1`'s unmapped outputs are guaranteed to be device-invariant;
instead, we're defensively `psum`ming them as if they weren't because `psum`'s
transpose can't be sure given the local information it has. Sometimes the `psum`
is necessary, as in transposing `f2` with respect to its first argument:

```python
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
```

Intuitively, if our transpose machinery could tell the difference between
Example 1 and Example 2, we could do better by avoiding the psum and division
where possible.

The inefficient examples can be even smaller. Consider transposing this cursed
identity function:

```python
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())

# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...
```

It keeps getting bigger the more we transpose. How embarrassing!

And `psum` isn't the only culprit. Something analogous holds true for
`all_gather`:

```python
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
```

This program is a bit artificial. Why do an `all_gather` and feed the result into
an unmapped output, rather than skipping the `all_gather` in the body and just
using `out_specs=P('i')` to collect the results? But even though it's cooked-up,
this example nevertheless exhibits a transpose which unnecessarily performs
communication (we could have just performed a non-communicating slice),
analogous to Example 1 for `psum`.

Also analogously to the `psum` examples, the defensive `psum_scatter` is
necessary in some cases:

```python
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))
```

So how do we avoid these inefficient transposes?
## Solutions

Here are two solution ideas. They aren't mutually exclusive. But (spoilers) the
second one is better, and it's all we need.

### Partial solution "P-sum": build the ability to express a `psum` into `out_specs`

This solution is a bit of a strawperson because it would offer only an awkward
way to write programs. And it wouldn't even fix everything! But it's worth
considering, if only to motivate a more complete solution.

Example 4 above is artificial because we could have just used `out_specs` instead
of an `all_gather` in the body:

```python
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())

# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
```

The `f4_better` version doesn't have any transposition problems, since the
transpose problems arise from collectives in the body.

Analogously, we could fix Example 1 by extending `out_specs` so that they can
express summing:

```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
```

So offering `psum`s built into `out_specs` fixes the transpose problem of
Example 1. But it doesn't fully fix the cursed identity transpose in Example 3:

```python
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
```

It's an improvement since the program doesn't continue to get bigger as we keep
transposing, but we're still doing wasteful communication.
### Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives

This solution has two components:
1. track when values are guaranteed to be device-invariant vs device-varying
   over particular mesh axes, and
2. decompose `psum` into a two-step process, introducing a new `pbroadcast`
   primitive, and introduce new primitives for `all_gather` and its transposes.

Morally, the tracking of device-invariant vs device-varying information is a
type-level consideration. But for the expedience of our first implementation, we
don't need to literally add the information to abstract values or jaxpr types.
Before we get to implementation, we'll first introduce the idea using types.

Also to follow is a discussion of making the user API convenient and backward
compatible. But to first introduce the idea, we'll ignore convenience and
instead write code that is as explicit as possible.

#### Tracking device invariance in avals (a.k.a. avals-with-names, revived)

We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a `shmap` are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We'll call such values device-invariant. For values that are not
device-invariant, we'll say they're device-varying, though really we mean
potentially device-varying from the point of view of the type system.

To encode device variance in types, we'll extend the syntax of types for arrays.
We'll write things like `x:f32[3,4]{i}` to indicate that `x` is (potentially)
device-varying along mesh axis `i` (and device-invariant over any other mesh
axes of the `shmap`). More generally, we'll say the grammar for array type
syntax is something like

```
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
```

We'll also update the typing rules to handle device variance types:
* for first-order primitives other than collectives
  - for multi-arity primitives, the operand device variance types must be equal
    where shapes must be equal, e.g. `mul x:f32[s1]{r1} y:f32[s2]{r2}` requires
    `r1 == r2` in addition to `s1 == s2`
  - the output device variance type must be the same as the operand(s)
* for higher-order primitives
  - we just instantiate any type variables including the device variance type
    (and checking types for equality checks their device variance types are
    equal)
  - (when performing type inference, e.g. for branches of a `cond`, we take the
    union of the sets of axis names in device variance types)
* for first-order collectives
  - a collective can either accept a device-varying or device-invariant input
    (along a mesh axis corresponding to its axis name parameter); it's an error
    to pass a device-invariant operand to a collective which accepts
    device-varying operands and vice-versa
  - a collective can either produce a device-varying or device-invariant output
  - see the table below

As a side benefit, whatever logic implements this type checking can subsume
`shmap`'s "static analysis" check for whether a `shmap` body function is
compatible with any unmapped `out_specs`.
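
To make the bookkeeping concrete, here's a toy sketch of these rules (our
illustration; the function names are hypothetical). It tracks each value's
*replication set*, the complement of the `{...}` variance set, which is how the
`shmap` machinery in the implementation below actually stores this information;
the two collective rules correspond to the `psum2` and `pbroadcast` rows of the
table that follows.

```python
# rep(x) = set of mesh axes over which x is guaranteed device-invariant
# (the complement of the {...} device variance set above).
def standard_rule(mesh_axes: set, *in_reps: set) -> set:
  # Non-collective primitives: the output is invariant only over axes that
  # every operand is invariant over (the equal-variance-types requirement).
  return set.intersection(*in_reps) if in_reps else set(mesh_axes)

def psum2_rule(in_rep: set, axes: tuple) -> set:
  return in_rep | set(axes)    # Varying -> Invariant over `axes`

def pbroadcast_rule(in_rep: set, axes: tuple) -> set:
  return in_rep - set(axes)    # Invariant -> Varying over `axes`

mesh_axes = {'i', 'j'}
rep = pbroadcast_rule(mesh_axes, ('i',))  # {'j'}, i.e. type f32[...]{i}
rep = psum2_rule(rep, ('i',))             # {'i', 'j'}, invariant again
```
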
Here's a table summarizing the device variance typing for collective primitives:

| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| `psum2` | `Varying -> Invariant` | `y:f32[3]{j} = psum2(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pbroadcast` |
| `pbroadcast` | `Invariant -> Varying` | `y:f32[3]{i} = pbroadcast(x:f32[3], 'i')` | no-op (no communication) | `psum` |
| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` | `AllToAll` (communication) | `all_to_all` |
| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |
| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |
| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |
| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |
| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |

There are some surprising things here!
* We introduced several new primitives, including
  - `pbroadcast`, which interestingly lowers to a no-op
  - `all_gather_invariant`, which lowers to the same thing as `all_gather` but
    has a different device variance type (essentially `all_gather` has a
    `pbroadcast` fused into it, whereas `all_gather_invariant` does not)
  - `pscatter`, which is the dual (transpose) of `all_gather_invariant`
* `all_gather` has a device-varying result

Intuitively, the reason to introduce `pbroadcast` (other than to make the typing
rules work) is so that `psum` can transpose to a physical no-op. The reason we
need `all_gather` to have a device-varying result is so that we can transpose it
to `psum_scatter`; if we instead left it with a device-invariant result, we
might need a downstream `pbroadcast`, and that composition would transpose to an
inefficient `psum` followed by slicing / `pscatter`. So instead we have a
`pbroadcast` "fused into" the `all_gather`, thus allowing for an efficient
transpose into `psum_scatter`. We provide `all_gather_invariant` and its
transpose `pscatter` mainly for completeness; it's unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using `out_specs`).
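
To restate the decomposition with a runnable toy model (our sketch; lists of
per-device values stand in for sharded arrays): the old, defensively-summing
`psum` applied to a device-invariant operand behaves like a `pbroadcast` fused
in front of the new `Varying -> Invariant` sum, which is why it communicates
only to scale by the axis size.

```python
# Toy model: a device-varying value is a list of per-device values.
def pbroadcast_toy(x, axis_size):   # Invariant -> Varying: no data movement,
  return [x] * axis_size            # each device just reuses its local copy

def psum2_toy(xs):                  # Varying -> Invariant: the real all-reduce
  return sum(xs)

# Old psum applied to an invariant x == pbroadcast then psum2: communication
# whose only effect is multiplying by the axis size (8 in the examples above).
x, axis_size = 3.0, 8
assert psum2_toy(pbroadcast_toy(x, axis_size)) == axis_size * x
```
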
Interestingly, the `psum` and `pbroadcast` transpose pair correspond to the
`psum_idrev` and `id_psumrev` that users introduced while training LLMs with
`pmap`.

#### How this system solves the inefficient transpose examples

Consider again the simplified motivating example:

```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y
```

With these new rules, the transpose is:

```python
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar
```

where evaluating the `pbroadcast` application involves no communication or FLOPs
at all; it's a no-op. Notice that if we keep transposing, the body does not grow
in size; indeed `t(t(f1)) == f1`. Efficiency achieved!

And we wouldn't mess up the other examples either, so long as we `pbroadcast` to
make the types check where needed:

```python
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device-varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
```

Intuitively, in Example 1 we now only have "half the original psum", whereas in
Example 2 we get both "halves". For Example 3 we never need any operations in
the body at all.
For the `all_gather` examples, Example 4 would need to use
`all_gather_invariant` to have an efficient transpose (though it'd be better to
use `out_specs` instead of the collective in the body):

```python
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar
```

For Example 5, using the device-varying `all_gather` works as we'd want:

```python
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
```
### How to make the API convenient for users (and backward compatible)

But what user wants to write `pbroadcast`s? And what developer wants to break
lots of existing user code involving `psum`s which are not fed into unmapped
outputs? Not me!

Instead we can automatically insert the `pbroadcast`s. It's a bit analogous to how
we do automatic rank promotion at the `jax.numpy` layer, inserting broadcasts to
avoid rank mismatch errors in binary operators. But it's much simpler since we
don't need to contend with shape tuples. The typical rule is: whenever we see a
multi-arity operation where the operands disagree in their device variance
types, take the union of operands' device variance types' axis name sets and
insert `pbroadcast`s to lift each operand to the resulting device variance type.
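
In replication-set terms (the complement of the variance sets), the rule
computes the output set and the axes each operand must be lifted over. Here's a
small runnable sketch of that computation (our illustration, with hypothetical
names; the real pass is the `_standard_rewrite_rule` in the implementation diff
further down):

```python
def pbroadcast_insertion(mesh_axes: set, in_reps: list) -> tuple:
  # Union of operands' variance sets == intersection of replication sets.
  out_rep = set.intersection(*in_reps) if in_reps else set(mesh_axes)
  # Per operand: the axes to pbroadcast over (empty tuple means no-op).
  casts = [tuple(sorted(rep - out_rep)) for rep in in_reps]
  return out_rep, casts

# x invariant over 'i', y varying over 'i': lift x with pbroadcast(x, ('i',))
# before e.g. a multiply; the product is device-varying over 'i'.
out_rep, casts = pbroadcast_insertion({'i'}, [{'i'}, set()])
assert out_rep == set() and casts == [('i',), ()]
```
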
Automatically inserting `pbroadcast`s just before they're needed may mean we apply
the same `pbroadcast` to the same operand multiple times, creating common
subexpressions. When we transpose, those could turn into a sum-of-`psum`s rather
than a `psum`-of-sum. We'll rely on the compiler to clean that up as appropriate.
If it's a problem then we could add some simple memoization to the
`pbroadcast`-insertion pass.

The user API for `all_gather` will mean `all_gather_p` by default (not
`all_gather_invariant_p`), covering the common case and meaning no `pbroadcast`s
must be inserted.

We can provide an option on `shmap` to disable this automatic insertion of
`pbroadcast`s, in which case it'll be up to the user to ensure type-correctness.
This explicit option may be appealing to some who want to be explicit about
where the `psum`s occur in the backward pass.

### How to implement the solution

The key to making the implementation lightweight is that **we aren't going to
add these types to avals or jaxprs**. At least, not at first. That can be
expensive because it requires updating the rest of JAX, e.g. all consumers of
avals and jaxprs may need to handle the new types. We're not falling for that
again!

Instead we're going to keep these extended types as metadata internal to
`shmap`, just like the current "replication checking for `out_specs`" machinery
is internal to `shmap`. Indeed this solution amounts to a relatively small
extension to that existing machinery: it was already tracking the same
information; now we're just adding the `pbroadcast`s.

We have at least two options for where to perform the `pbroadcast` insertion:
1. just before transposition, in the transpose rule, where we have a jaxpr of
   the computation to be transposed;
2. in every `shmap` body, whether eagerly executed or staged out, like the
   current "replication checking for `out_specs`" machinery.

The former may end up being easier since we only have to handle the jaxpr case,
and only linear primitives. But we'll start by trying the latter so the
implementation here is a strict revision/extension to the existing
replication-checking logic.
## Appendix: defining and motivating maps with unmapped inputs and outputs

For concreteness, we'll mostly focus on `shmap`, though these same ideas apply
to e.g. `pmap` and probably `xmap`.

An argument/input is _unmapped_ along a mesh axis when the corresponding entry
of `in_specs` doesn't mention that mesh axis's name. Logically it means that
each function instance along that mesh axis gets the same value for the
argument. To the caller, each operand is sliced according to the mesh axes over
which the operand is mapped, whereas there is no slicing for mesh axes over
which the operand is unmapped.

An output is _unmapped_ along a mesh axis when the corresponding entry of
`out_specs` doesn't mention that mesh axis's name. Logically it means each
function instance along that mesh axis must return the same value. To the
caller, each result of the `shmap` is formed by concatenating the return values
of every function instance along which the outputs are mapped, whereas for mesh
axes over which the output is unmapped only one copy of the value is used.
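
For instance, in the notation of the earlier examples (a sketch, on a mesh
axis `'i'` of size 8):

```python
# w is unmapped: the in_specs entry P() gives every instance the same full value.
# The output is unmapped: out_specs=P() requires all instances to return the
# same value, and the caller gets one copy rather than a concatenation.
g = shmap(lambda w, x: psum(w * jnp.sum(x), 'i'),
          in_specs=(P(), P('i')), out_specs=P())
```
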
See [the `shmap`
JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples
of unmapped inputs and outputs. For comparison, in `vmap` unmapped
inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather
than an `int`).
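
Concretely, the `vmap` analogue looks like this (our sketch):

```python
import jax
import jax.numpy as jnp

# Unmapped input: in_axes=None shares w across all instances.
f = jax.vmap(lambda w, x: w * x, in_axes=(None, 0))
f(jnp.float32(2.), jnp.arange(3.))   # Array([0., 2., 4.], dtype=float32)

# Unmapped output: out_axes=None is allowed only when the result provably
# doesn't depend on the mapped axis, mirroring shmap's out_specs=P().
h = jax.vmap(lambda w, x: w * 2., in_axes=(None, 0), out_axes=None)
h(jnp.float32(2.), jnp.arange(3.))   # Array(4., dtype=float32)
```
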
Here are reasons we like unmapped inputs and outputs for `shmap`:
* **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape
  hatch should be able to do too. Or else we'd have a lacking escape hatch! If
  we didn't have unmapped outputs in `shmap` then we couldn't express the same
  batch-parallel loss function computations as `pjit`.
* **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped
  inputs, and...
* **Closure under transposition.** Once we have unmapped inputs, it's natural to
  be able to transpose to unmapped outputs.

So unmapped outputs are both canonical and useful!
@ -48,6 +48,7 @@ Then create a pull request that adds a file named
    12049: Type Annotation Roadmap for JAX <12049-type-annotations>
    14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
    15856: `jax.extend`, an extensions module <15856-jex>
+   17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>

 Several early JEPs were converted in hindsight from other documentation,
@ -1084,8 +1084,7 @@ def _why_alive_container_info(container, obj_id) -> str:
 @contextmanager
-def new_main(trace_type: type[Trace],
-             dynamic: bool = False,
+def new_main(trace_type: type[Trace], dynamic: bool = False,
              **payload) -> Generator[MainTrace, None, None]:
   # See comments in https://github.com/google/jax/pull/3370
   stack = thread_local_state.trace_state.trace_stack
@ -1111,6 +1110,20 @@ def new_main(trace_type: type[Trace],
         leaked_tracers = maybe_find_leaked_tracers(t())
         if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)

+@contextmanager
+def new_dynamic(level: int) -> Generator[None, None, None]:
+  stack = thread_local_state.trace_state.trace_stack
+  prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level]
+  _update_thread_local_jit_state(stack.dynamic)
+  try:
+    yield
+  finally:
+    stack.dynamic = prev_dynamic
+    _update_thread_local_jit_state(stack.dynamic)
+
 def dynamic_level() -> int:
   return thread_local_state.trace_state.trace_stack.dynamic.level

 @contextmanager
 def new_base_main(trace_type: type[Trace],
                   **payload) -> Generator[MainTrace, None, None]:
@ -853,6 +853,7 @@ def _custom_vjp_call_jaxpr_jvp(
   tangents_out = ad.custom_lin_p.bind(
       *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
       out_avals=avals_out, symbolic_zeros=symbolic_zeros)
+  tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
   tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
   return primals_out, tangents_out
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp
@ -403,6 +403,7 @@ class JVPTrace(Trace):
     tangents_out = custom_lin_p.bind(
         *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
         out_avals=avals_out, symbolic_zeros=symbolic_zeros)
+    tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out)
     tangents_out = map(recast_to_float0, primals_out, tangents_out)
     return map(partial(JVPTracer, self), primals_out, tangents_out)
@ -4966,6 +4966,17 @@ def _empty_lower(ctx, *, dtype):
 mlir.register_lowering(empty_p, _empty_lower)


+tie_p = core.Primitive('tie')
+tie_p.def_impl(lambda x, y: y)
+tie_p.def_abstract_eval(lambda x, y: y)
+mlir.register_lowering(tie_p, lambda ctx, x, y: [y])
+ad.primitive_jvps[tie_p] = \
+    lambda primals, tangents: (tie_p.bind(*primals), tangents[-1])
+ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct]
+pe.def_trivial_padding(tie_p)
+batching.defvectorized(tie_p)
+
+
 class BIntRules:
   @staticmethod
   def physical_element_aval(dtype) -> core.ShapedArray:
@ -812,6 +812,7 @@ batching.axis_primitive_batchers[psum_p] = \
     partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
 core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes')


 # We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
 # tracing time.
 @psum_p.def_custom_bind
@ -1465,6 +1465,8 @@ tf_not_yet_impl = [
     "pmin",
     "ppermute",
     "psum",
+    "psum2",
+    "pbroadcast",
     "pmax",
     "pgather",
     "reduce_scatter",
@ -3449,6 +3451,9 @@ def _reduce_precision(x, *, exponent_bits, mantissa_bits):
 tf_impl[lax.reduce_precision_p] = _reduce_precision

+tf_impl[lax_internal.tie_p] = lambda x, y: y


 def _register_checkpoint_pytrees():
   """Registers TF custom container types as pytrees."""
   m = tf.Module()
@ -18,7 +18,7 @@ import enum
 from functools import partial
 import inspect
 import itertools as it
-import math
+from math import prod
 import operator as op
 from typing import Any, Callable, Optional, TypeVar, Union
@ -104,7 +104,7 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
     raise e('shard_map in_specs') from None
   _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat)
   in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
-  flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
+  fun, out_tree = flatten_fun_nokwargs(fun, in_tree)

   @memoize
   def out_names_thunk():
@ -119,10 +119,15 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
       e, *_ = prefix_errors(out_specs_, dummy)
       raise e('shard_map out_specs') from None
     return tuple(map(_canonicalize_spec, out_specs_flat))

+  if rewrite := check_rep:
+    fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk)
+
   try:
     out_flat = shard_map_p.bind(
-        flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat,
-        out_names_thunk=out_names_thunk, check_rep=check_rep, auto=auto)
+        fun, *args_flat, mesh=mesh, in_names=in_names_flat,
+        out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite,
+        auto=auto)
   except _SpecError as e:
     fails, = e.args
     if not callable(out_specs):
@ -135,12 +140,12 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
   except _RepError as e:
     fails, = e.args
     if not callable(out_specs):
-      msg = _rep_error(f, mesh, out_tree(), out_specs, fails)
+      msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails)
       raise ValueError(msg) from None
     return tree_unflatten(out_tree(), out_flat)
   return wrapped

-# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs
+# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs
 AxisNames = dict[int, tuple[AxisName, ...]]  # TODO(mattjj): make it hashable
 def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:
   if isinstance(spec, PartitionSpec):
@ -183,7 +188,7 @@ def _check_specs_vs_args(
     msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
     raise ValueError(msg)
   in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
-  fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns)
+  fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
                    for d, ns in names.items()) else no_fail
           for a, names in zip(in_avals, in_names_flat)]
   if any(f is not no_fail for f in fail):
@ -239,10 +244,10 @@ def _spec_divisibility_error(
            f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
     names = _canonicalize_spec(spec)
     for d, ns in names.items():
-      if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
+      if aval.shape[d] % prod(mesh.shape[n] for n in ns):
         axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
         total = 'total ' if len(ns) > 1 else ''
-        sz = math.prod(mesh.shape[n] for n in ns)
+        sz = prod(mesh.shape[n] for n in ns)
         msgs.append(
             f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
             f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
@ -263,7 +268,7 @@ def _spec_divisibility_error(
             f"padding the input and adapting '{fun_name}' appropriately.")
   return msg

-def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
+def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
                fails: list[set | NoFail]) -> str:
   fun_name = getattr(f, '__name__', str(f))
   msgs = []
@ -332,10 +337,11 @@ class ShardMapPrimitive(core.Primitive):
   def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh,
            in_names: tuple[AxisNames, ...],
            out_names_thunk: Callable[[], tuple[AxisNames, ...]],
-           check_rep: bool, auto: frozenset[AxisName]) -> Sequence[MaybeTracer]:
+           check_rep: bool, rewrite: bool, auto: frozenset[AxisName]
+           ) -> Sequence[MaybeTracer]:
     top_trace = core.find_top_trace(args)
     fun, env_todo = process_env_traces(fun, top_trace.level, mesh, in_names,
-                                       out_names_thunk, check_rep, auto)
+                                       out_names_thunk, check_rep, rewrite, auto)

     @as_hashable_function(closure=out_names_thunk)
     def new_out_names_thunk():
@ -348,7 +354,8 @@ class ShardMapPrimitive(core.Primitive):
     tracers = map(top_trace.full_raise, args)
     outs = top_trace.process_shard_map(  # pytype: disable=attribute-error
         shard_map_p, fun, tracers, mesh=mesh, in_names=in_names,
-        out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto)
+        out_names_thunk=new_out_names_thunk, check_rep=check_rep,
+        rewrite=rewrite, auto=auto)
     todos, _ = env_todo()
     return map(core.full_lower, core.apply_todos(todos, outs))
@ -364,7 +371,7 @@ shard_map_p = ShardMapPrimitive('shard_map')
 @lu.transformation_with_aux
 def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
-                       auto, *args: Any):
+                       rewrite, auto, *args: Any):
   outs = yield args, {}
   todos, out_names_transforms = [], []
   while True:
@ -377,7 +384,7 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
     trace = ans._trace.main.with_cur_sublevel()
     outs = map(trace.full_raise, outs)
     outs, (todo, xform) = trace.post_process_shard_map(
-        outs, mesh, in_names, out_names_thunk, check_rep, auto)
+        outs, mesh, in_names, out_names_thunk, check_rep, rewrite, auto)
     todos.append(todo)
     out_names_transforms.append(xform)
   yield outs, (tuple(todos), tuple(out_names_transforms))
@ -390,19 +397,19 @@ def _shard_map_staging(
     in_names: tuple[AxisNames, ...],
     out_names_thunk: Callable[[], tuple[AxisNames, ...]],
     check_rep: bool,
+    rewrite: bool,
     auto: frozenset,
 ) -> Sequence[pe.DynamicJaxprTracer]:
   in_avals = [t.aval for t in in_tracers]
   in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
   main = trace.main
   with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(
-        f, main, in_avals_)
-  out_avals_ = map(_check_shapedarray, out_avals_generic)
+    jaxpr, genavals, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
+  out_avals_ = map(_check_shapedarray, genavals)
   _check_names(out_names_thunk(), out_avals_)
-  if check_rep:
-    in_rep = map(partial(_in_names_to_rep, mesh), in_names)
-    out_rep = _output_rep(mesh, jaxpr, in_rep)
+  in_rep = map(partial(_in_names_to_rep, mesh), in_names)
+  if check_rep:
+    out_rep = _check_rep(mesh, jaxpr, in_rep)
     _check_reps(mesh, out_names_thunk(), out_rep)
   out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_)
   source_info = source_info_util.current()
@ -415,13 +422,80 @@ def _shard_map_staging(
   jaxpr = pe.convert_constvars_jaxpr(jaxpr)
   params = dict(mesh=mesh, in_names=in_names_staged,
                 out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
-                check_rep=check_rep, auto=auto)
+                check_rep=check_rep, rewrite=rewrite, auto=auto)
   eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
                          jaxpr.effects, source_info)
   trace.frame.add_eqn(eqn)
   return out_tracers
 pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
+
+
+Val = Any
+
+# TODO(mattjj): caching
+def _replication_rewrite_match(
+    mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]],
+    out_rep_dst: Sequence[set[AxisName]],
+) -> core.ClosedJaxpr:
+  f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep))
+  f = _match_rep(f, mesh, out_rep_dst)
+  with core.extend_axis_env_nd(mesh.shape.items()):
+    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
+  return core.ClosedJaxpr(jaxpr_, consts)
+
+@lu.transformation
+def _match_rep(mesh: Mesh, out_rep_dst: Sequence[set[AxisName]], *args):
+  out_vals, out_reps = yield args, {}
+  _check_reps2(mesh, out_rep_dst, out_reps)
+  out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
+              else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)]
+  yield out_vals
+
+def _rep_rewrite(
+    mesh: Mesh, jaxpr_: core.ClosedJaxpr,
+    in_rep: Sequence[set[AxisName]], *args: Val,
+) -> tuple[tuple[Val], tuple[set[AxisName]]]:
+  jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
+
+  env: dict[core.Var, tuple[Val, set[AxisName]]] = {}
+
+  def read(x: core.Atom) -> tuple[Val, set[AxisName]]:
+    return env[x] if isinstance(x, core.Var) else (x.val, set(mesh.axis_names))
+
+  def write(v: core.Var, val: Val, rep: set[AxisName]) -> None:
+    env[v] = (val, rep)
+
+  map(write, jaxpr.constvars, consts, [set(mesh.axis_names)] * len(consts))
+  map(write, jaxpr.invars, args, in_rep)
+  for e in jaxpr.eqns:
+    rule = _rewrite_rules.get(e.primitive, partial(_rule_missing, e.primitive))
+    in_vals, in_reps = unzip2(map(read, e.invars))
+    out_vals, out_reps = rule(mesh, in_reps, *in_vals, **e.params)
+    map(write, e.outvars, out_vals, out_reps)
+  out_vals, out_reps = unzip2(map(read, jaxpr.outvars))
+  return out_vals, out_reps
+
+def _rule_missing(prim: core.Primitive, *_, **__):
+  raise NotImplementedError(
+      f"No replication rule for {prim}. As a workaround, pass the "
+      "`check_rep=False` argument to `shard_map`. To get this fixed, open an "
+      "issue at https://github.com/google/jax/issues")
+
+def _replication_rewrite_nomatch(
+    mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]],
+) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
+  f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep))
+  f, out_rep = _grab_out_rep(f)
+  with core.extend_axis_env_nd(mesh.shape.items()):
+    jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
+  return core.ClosedJaxpr(jaxpr_, consts), list(out_rep())
+
+@lu.transformation_with_aux
+def _grab_out_rep(*args):
+  yield (yield args, {})
+
 def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
   assert isinstance(aval, core.ShapedArray)
   return aval
@ -429,7 +503,7 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
 def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
                 ) -> core.AbstractValue:
   if isinstance(aval, core.ShapedArray):
-    return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ()))
+    return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
                              for i, sz in enumerate(aval.shape)))
   else:
     raise NotImplementedError  # TODO(mattjj): add table with handlers
||||
@ -437,7 +511,7 @@ def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
) -> core.AbstractValue:
|
||||
if isinstance(aval, core.ShapedArray):
|
||||
return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ()))
|
||||
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
|
||||
for i, sz in enumerate(aval.shape)),
|
||||
named_shape={k: v for k, v in aval.named_shape.items()
|
||||
if k not in mesh.shape})
|
||||
@ -447,7 +521,7 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
 # Type-checking

 def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
-                         check_rep, auto):
+                         check_rep, rewrite, auto):
   del auto  # TODO(mattjj,parkers): check
   for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
     if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
@ -457,11 +531,11 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
   core.check_jaxpr(jaxpr)
   if check_rep:
     in_rep = map(partial(_in_names_to_rep, mesh), in_names)
-    out_rep = _output_rep(mesh, jaxpr, in_rep)
+    out_rep = _check_rep(mesh, jaxpr, in_rep)
     for rep, dst in zip(out_rep, out_names):
       if not _valid_repeats(mesh, rep, dst):
-        raise core.JaxprTypeError("shard_map can't prove output is sufficiently "
-                                  "replicated")
+        raise core.JaxprTypeError("shard_map can't prove output is "
+                                  "sufficiently replicated")
   out_avals_sharded = [x.aval for x in jaxpr.outvars]
   out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
   return out_avals, jaxpr.effects
@ -470,7 +544,7 @@ core.custom_typechecks[shard_map_p] = _shard_map_typecheck
 def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]:
   return set(mesh.axis_names) - {n for ns in names.values() for n in ns}

-def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]],
+def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]],
                ) -> Sequence[set[AxisName]]:
   env: dict[core.Var, set[AxisName]] = {}
@ -484,7 +558,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]],
   map(write, jaxpr.invars, in_rep)
   last_used = core.last_used(jaxpr)
   for e in jaxpr.eqns:
-    rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
+    rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
     out_rep = rule(mesh, *map(read, e.invars), **e.params)
     if e.primitive.multiple_results:
       out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
@ -500,8 +574,8 @@ def _valid_repeats(mesh: Mesh, rep: set[AxisName], dst: AxisNames) -> bool:
 # Lowering

 def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
-                        check_rep, auto):
-  del check_rep
+                        check_rep, rewrite, auto):
+  del check_rep, rewrite
   in_avals_ = [v.aval for v in jaxpr.invars]
   out_avals_ = [x.aval for x in jaxpr.outvars]
   in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
@ -557,7 +631,7 @@ def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
 # Eager evaluation

 def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
-                    check_rep, auto):
+                    check_rep, rewrite, auto):
   if auto: raise NotImplementedError
   del prim, auto
   args = map(partial(_unmatch_spec, mesh), in_names, args)
@ -572,8 +646,10 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
   del main, t, in_tracers, ans, out_tracers
   out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
   _check_names(out_names_thunk(), out_avals)  # pytype: disable=wrong-arg-types
-  if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
-  return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
+  if check_rep:
+    _check_reps(mesh, out_names_thunk(), out_rep)
+  return map(partial(_match_spec, mesh, check_rep), out_rep, out_names_thunk(),
+             outs_)
 core.EvalTrace.process_shard_map = _shard_map_impl

 def _names_to_pspec(names: AxisNames) -> PartitionSpec:
||||
@ -587,7 +663,7 @@ def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType:
|
||||
def _unmatch(mesh, src_tup, x):
|
||||
src = _names_to_pspec(dict(src_tup))
|
||||
dst = P(mesh.axis_names)
|
||||
return shard_map(_add_singleton, mesh, (src,), dst)(x)
|
||||
return shard_map(_add_singleton, mesh, (src,), dst, check_rep=False)(x)
|
||||
|
||||
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
|
||||
) -> None:
|
||||
@ -602,14 +678,21 @@ def _check_reps(mesh, names, reps):
   if any(f is not no_fail for f in fail): raise _RepError(fail)
 class _RepError(Exception): pass

-def _match_spec(mesh: Mesh, rep: set[AxisName], dst: AxisNames, x: JaxType
-                ) -> JaxType:
-  with core.eval_context():
-    return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x)
+def _check_reps2(mesh, reps_dest, reps):
+  fail = [src if not dst.issubset(src) else no_fail
+          for dst, src in zip(reps_dest, reps)]
+  if any(f is not no_fail for f in fail): raise _RepError(fail)

-def _match(mesh, dst_tup, x):
+def _match_spec(mesh: Mesh, check_rep: bool,
+                rep: set[AxisName], dst: AxisNames, x: JaxType) -> JaxType:
+  fn = HashablePartial(_match, mesh, check_rep, tuple(dst.items()))
+  with core.eval_context():
+    return jax.jit(fn)(x)
+
+def _match(mesh, check_rep, dst_tup, x):
   src = P(mesh.axis_names)
   dst = _names_to_pspec(dict(dst_tup))
-  return shard_map(_rem_singleton, mesh, (src,), dst)(x)
+  # TODO put back (?) needed for rep checking in eager? for now test rewrite
+  return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x)

 def _rem_singleton(x): return x.reshape(x.shape[1:])
@ -640,7 +723,7 @@ class ShardMapTrace(core.Trace):
     f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
     with core.eval_context(), jax.disable_jit(False):
       out_vals = jax.jit(f)(*in_vals)
-    rep_rule = _rep_rules.get(prim, partial(_rep_rule, prim))
+    rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
     out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
     if prim.multiple_results:
       out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
@ -758,92 +841,312 @@ def _device_put_eager_rule(mesh, x, *, src, device):
|
||||
f"shard_map-decorated functions, but got device {device}")
|
||||
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
|
||||
|
||||
# Static replication checking
|
||||
# New primitives for efficient transposition
|
||||
|
||||
def _rep_rule(prim: core.Primitive, mesh: Mesh, *in_rep: set[AxisName],
|
||||
**params: Any) -> set[AxisName] | list[set[AxisName]]:
|
||||
raise NotImplementedError(
|
||||
f"No replication rule for {prim}. As a workaround, pass the "
|
||||
"`check_rep=False` argument to `shard_map`. To get this fixed, open an "
|
||||
"issue at https://github.com/google/jax/issues")
|
||||
# psum2_p is like psum_p except has a different transpose, so mostly copied:
|
||||
psum2_p = core.AxisPrimitive('psum2')
|
||||
psum2_p.multiple_results = True
|
||||
psum2_p.def_impl(lax_parallel.psum_p.impl)
|
||||
psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval)
|
||||
mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p])
|
||||
batching.primitive_batchers[psum2_p] = partial(lax_parallel._reduction_batcher, psum2_p)
|
||||
batching.axis_primitive_batchers[psum2_p] = \
|
||||
partial(lax_parallel._batched_reduction_collective, psum2_p,
|
||||
lambda v, axis_size: axis_size * v)
|
||||
core.axis_substitution_rules[psum2_p] = \
|
||||
partial(lax_parallel._subst_all_names_in_param, 'axes')
|
||||
def _psum2_transpose_rule(cts, *args, axes, axis_index_groups):
|
||||
del args
|
||||
return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
|
||||
ad.deflinear2(psum2_p, _psum2_transpose_rule)
|
||||
|
||||
_rep_rules: dict[core.Primitive, Callable] = {}
|
||||
register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule)
|
||||
register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule)
|
||||
# pbroadcast_p is exactly the transpose of psum2_p
|
||||
def pbroadcast(x, axis_name):
|
||||
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
|
||||
xs, treedef = tree_flatten(x)
|
||||
ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None)
|
||||
return tree_unflatten(treedef, ys)
|
||||
pbroadcast_p = core.AxisPrimitive('pbroadcast')
|
||||
pbroadcast_p.multiple_results = True
|
||||
pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args)
|
||||
pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args)
|
||||
mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x)
|
||||
def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups):
|
||||
if any(type(axis) is int for axis in axes): raise NotImplementedError
|
||||
vals_out = pbroadcast_p.bind(*vals_in, axes=axes,
|
||||
axis_index_groups=axis_index_groups)
|
||||
return vals_out, dims_in
|
||||
batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher
|
||||
def _pbroadcast_axis_batcher(size, name, trace_type, vals_in, dims_in, *, axes,
|
||||
groups):
|
||||
raise NotImplementedError # vmap with axis name involved in this primitive
|
||||
batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_axis_batcher
|
||||
core.axis_substitution_rules[pbroadcast_p] = \
|
||||
partial(lax_parallel._subst_all_names_in_param, 'axes')
|
||||
ad.deflinear2(pbroadcast_p,
|
||||
lambda cts, *_, axes, axis_index_groups:
|
||||
psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups))
|
||||
|
||||
# Rewrite rules and static replication checking for efficient transposition
|
||||
|
||||
_rewrite_rules: dict[core.Primitive, Callable] = {}
|
||||
register_rewrite = lambda prim: lambda r: _rewrite_rules.setdefault(prim, r)
|
||||
register_standard_rewrite = lambda prim: \
|
||||
_rewrite_rules.setdefault(prim, partial(_standard_rewrite_rule, prim))
|
||||
register_norewrite = lambda p: \
|
||||
_rewrite_rules.setdefault(p, partial(_no_rewrite, p, _check_rules[p]))
|
||||
|
||||
_check_rules: dict[core.Primitive, Callable] = {}
|
||||
register_check = lambda prim: lambda rule: _check_rules.setdefault(prim, rule)
|
||||
register_standard_check = \
|
||||
lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim))
|
||||
|
||||
def _no_rewrite(prim, rule, mesh, in_rep, *args, **params):
|
||||
out_vals = prim.bind(*args,**params)
|
||||
out_rep = rule(mesh, *in_rep, **params)
|
||||
if prim.multiple_results:
|
||||
out_rep_ = out_rep if type(out_rep) is list else [out_rep] * len(out_vals)
|
||||
else:
|
||||
out_vals, out_rep_ = [out_vals], [out_rep]
|
||||
return out_vals, out_rep_
|
||||
|
||||
def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params):
|
||||
# The standard rewrite inserts pbroadcasts but doesn't change the primitive.
|
||||
out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names)
|
||||
args_ = [pbroadcast(x, tuple(n for n in src if n not in out_rep_))
|
||||
if src - out_rep_ else x for x, src in zip(args, in_rep)]
|
||||
out_vals_ = prim.bind(*args_, **params)
|
||||
out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_]
|
||||
out_vals = [out_vals_] if not prim.multiple_results else out_vals_
|
||||
return out_vals, out_rep
|
||||
|
||||
def _standard_check(prim, mesh, *in_rep, **__):
|
||||
# The standard check require args' and outputs' replications to be the same.
|
||||
if in_rep and not in_rep[:-1] == in_rep[1:]:
|
||||
raise Exception(f"Primitive {prim} requires argument replication types "
|
||||
f"to match, but got {in_rep}. Please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
return in_rep[0] if in_rep else set(mesh.axis_names)
|
||||
|
||||
def register_standard_collective(prim):
|
||||
register_check(prim)(partial(_standard_collective_check, prim))
|
||||
register_rewrite(prim)(partial(_standard_collective_rewrite, prim))
|
||||
|
||||
def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
|
||||
# The standard collective check is varying -> varying over axis_name.
|
||||
del mesh, params
|
||||
if axis_name in x_rep:
|
||||
raise Exception(f"Collective {prim} must be applied to a device-varying "
|
||||
f"replication type, but got {x_rep} for collective acting "
|
||||
f"over axis name {axis_name}. Please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
return x_rep
|
||||
|
||||
def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
  # The standard collective rewrite may insert a pbroadcast on the input.
  if type(axis_name) is tuple: raise NotImplementedError  # TODO
  if params.get('axis_index_groups') is not None: raise NotImplementedError
  x_rep, = in_rep
  if axis_name in x_rep:
    x = pbroadcast(x, (axis_name,))
  out_val = prim.bind(x, axis_name=axis_name, **params)
  return [out_val], [x_rep - {axis_name}]
-def _standard_rep_rule(_, *in_rep, **__):
-  return set.intersection(*in_rep) if in_rep else set()

for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
                  windowed_reductions.__dict__.values(), fft.__dict__.values(),
                  linalg.__dict__.values(), ops.__dict__.values(),
-                 ad_util.__dict__.values(), prng.__dict__.values()):
-  if isinstance(o, core.Primitive): register_standard(o)
                  ad_util.__dict__.values(), prng.__dict__.values(),
                  custom_derivatives.__dict__.values()):
  if isinstance(o, core.Primitive):
    register_standard_check(o)
    register_standard_rewrite(o)

-register_standard(lax_parallel.ppermute_p)  # doesn't change replication
-@register_rule(lax_parallel.psum_p)
-def _psum_rule(_, *in_rep, axes, axis_index_groups):
-  if axis_index_groups is not None: raise NotImplementedError
-  axes = (axes,) if not isinstance(axes, tuple) else axes
-  return [r | set(axes) for r in in_rep]  # introduces replication

@register_check(lax_parallel.psum_p)
def _psum_check(_, *in_rep, axes, axis_index_groups):
  assert False  # should be rewritten away

@register_rewrite(lax_parallel.psum_p)
def _psum_rewrite(_, in_rep, *args, axes, axis_index_groups):
  # Replace the psum with psum2, insert pbroadcasts on input, replicated output.
  if axis_index_groups is not None: raise NotImplementedError
  axes = (axes,) if not isinstance(axes, tuple) else axes
  out_rep = [r | set(axes) for r in in_rep]  # TODO determinism (and elsewhere)
  args_ = [pbroadcast(x, tuple(n for n in src if n not in dst))
           if src - dst else x for x, src, dst in zip(args, in_rep, out_rep)]
  out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups)
  return out_val, out_rep
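# Hedged intuition for the rewrite above, in plain set arithmetic: an
# old-style psum over 'x' applied to an 'x'-replicated value becomes a
# pbroadcast (rep {'x'} -> set()) followed by a psum2 (rep set() -> {'x'}),
# so the replication the caller observes is unchanged.
assert (({'x'} - {'x'}) | {'x'}) == {'x'}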
-@register_rule(lax_parallel.all_gather_p)
-def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size,
-                     axis_index_groups, tiled):
-  if axis_index_groups is not None: raise NotImplementedError
-  if not tiled: raise NotImplementedError
-  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
-  return in_rep | set(axis_name)  # introduces replication

-@register_rule(lax_parallel.reduce_scatter_p)
-def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size,
-                         axis_index_groups, tiled):
-  if axis_index_groups is not None: raise NotImplementedError
-  if not tiled: raise NotImplementedError
-  return in_rep - {axis_name}  # removes replication
@register_check(psum2_p)
def _psum2_check(_, *in_rep, axes, axis_index_groups):
  assert type(axes) is tuple
  if any(set(axes) & r for r in in_rep):
    raise Exception("Collective psum must be applied to a device-varying "
                    f"replication type, but got {in_rep} for collective acting "
                    f"over axis name {axes}. Please open an issue at "
                    "https://github.com/google/jax/issues")
  return [r | set(axes) for r in in_rep]
register_norewrite(psum2_p)
-@register_rule(lax_parallel.all_to_all_p)
-def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name,
-                     axis_index_groups):
-  if axis_index_groups is not None: raise NotImplementedError
-  return in_rep - {axis_name}  # removes replication

-@register_rule(lax_parallel.axis_index_p)
-def _axis_index_rule(mesh, *, axis_name):
-  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
-  return set(mesh.shape) - set(axis_name)
@register_check(pbroadcast_p)
def _pbroadcast_check(_, *in_rep, axes, axis_index_groups):
  assert type(axes) is tuple
  if not all(set(axes) & r for r in in_rep):
    raise Exception("Collective pbroadcast must be applied to a "
                    "non-device-varying "
                    f"replication type, but got {in_rep} for collective acting "
                    f"over axis name {axes}. Please open an issue at "
                    "https://github.com/google/jax/issues")
  return [r - set(axes) for r in in_rep]
register_norewrite(pbroadcast_p)

register_standard_collective(lax_parallel.all_gather_p)
register_standard_collective(lax_parallel.all_to_all_p)
register_standard_collective(lax_parallel.ppermute_p)
register_standard_collective(lax_parallel.reduce_scatter_p)

@register_check(lax_parallel.axis_index_p)
def _axis_index_check(mesh, *, axis_name):
  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  return set(mesh.shape) - set(axis_name)
register_norewrite(lax_parallel.axis_index_p)
-@register_rule(pjit.pjit_p)
-def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
-  return _output_rep(mesh, jaxpr.jaxpr, in_rep)

@register_rewrite(pjit.pjit_p)
def _pjit_rewrite(mesh, in_rep, *args, jaxpr, **kwargs):
  jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep)
  out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs)
  return out_vals, out_rep

@register_check(pjit.pjit_p)
def _pjit_check(mesh, *in_rep, jaxpr, **kwargs):
  return _check_rep(mesh, jaxpr.jaxpr, in_rep)

@register_check(core.call_p)
def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
  return _check_rep(mesh, call_jaxpr, in_rep)

-@register_rule(debugging.debug_callback_p)
-def _debug_callback_rule(mesh, *in_rep, **_):
-  return []

@register_check(debugging.debug_callback_p)
def _debug_callback_rule(mesh, *in_rep, **_):
  return []
register_norewrite(debugging.debug_callback_p)
-@register_rule(callback.pure_callback_p)
-def _pure_callback_rule(mesh, *in_rep, result_avals, **_):
-  return [set()] * len(result_avals)

@register_check(callback.pure_callback_p)
def _pure_callback_rule(mesh, *_, result_avals, **__):
  return [set()] * len(result_avals)
register_norewrite(callback.pure_callback_p)
-@register_rule(dispatch.device_put_p)
-def _device_put_rep_rule(mesh, x, *, src, device):
-  return x

@register_check(dispatch.device_put_p)
def _device_put_rule(mesh, x, **_):
  return x
register_norewrite(dispatch.device_put_p)
-@register_rule(control_flow.loops.scan_p)
-def _scan_rule(mesh, *in_rep, jaxpr, num_consts, num_carry, linear, length,
-               reverse, unroll):
-  const_rep, carry_rep, xs_rep = split_list(in_rep, [num_consts, num_carry])
-  for _ in range(1 + num_carry):
-    out_rep = _output_rep(mesh, jaxpr.jaxpr, [*const_rep, *carry_rep, *xs_rep])
-    if carry_rep == out_rep[:num_carry]:
-      break
-    else:
-      carry_rep = map(op.and_, carry_rep, out_rep[:num_carry])
-  else:
-    assert False, 'Fixpoint not reached'
-  return out_rep

@register_check(ad.custom_lin_p)
def _custom_lin_rule(mesh, *_, out_avals, **__):
  return [set()] * len(out_avals)
register_norewrite(ad.custom_lin_p)

@register_check(control_flow.loops.scan_p)
def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
  _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry])
  out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep)
  carry_rep_out, _ = split_list(out_rep, [num_carry])
  if not carry_rep_in == carry_rep_out:
    raise Exception("Scan carry input and output got mismatched replication "
                    f"types {carry_rep_in} and {carry_rep_out}. Please open an "
                    "issue at https://github.com/google/jax/issues")
  return out_rep
@register_rewrite(control_flow.loops.scan_p)
def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params):
  const_rep, carry_rep_in, xs_rep = split_list(in_rep, [num_consts, num_carry])
  for _ in range(1 + num_carry):
    in_rep_ = [*const_rep, *carry_rep_in, *xs_rep]
    _, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep_)
    carry_rep_out, ys_rep = split_list(out_rep, [num_carry])
    carry_rep_out = map(op.and_, carry_rep_in, carry_rep_out)
    if carry_rep_in == carry_rep_out:
      break
    else:
      carry_rep_in = carry_rep_out
  else:
    assert False, 'Fixpoint not reached'

  args = [pbroadcast(x, tuple(n for n in src if n not in dst))
          if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)]
  out_rep = [*carry_rep_out, *ys_rep]
  jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep)

  out_vals = control_flow.loops.scan_p.bind(
      *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params)
  return out_vals, out_rep

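# A hedged miniature of the carry fixpoint above, with one carry and plain
# sets standing in for replication types: repeatedly intersect the carry's
# input rep with its output rep until they agree.
def _toy_scan_fixpoint(carry_rep, body_rep):  # body_rep: rep -> rep
  for _ in range(1 + 1):  # 1 + num_carry iterations suffice
    out = body_rep(carry_rep)
    if carry_rep == out:
      return carry_rep
    carry_rep = carry_rep & out
  assert False, 'Fixpoint not reached'
assert _toy_scan_fixpoint({'x', 'y'}, lambda r: r - {'x'}) == {'y'}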
@register_rewrite(core.closed_call_p)
def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs):
  new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep)
  out_vals = core.closed_call_p.bind(*args, call_jaxpr=new_jaxpr, **kwargs)
  return out_vals, out_rep

@register_check(core.closed_call_p)
def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
  return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)

@register_check(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_thunk,
                           num_consts, symbolic_zeros):
  return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)
@register_rewrite(custom_derivatives.custom_vjp_call_jaxpr_p)
def _custom_vjp_call_jaxpr_rewrite(
    mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees,
    symbolic_zeros):
  if symbolic_zeros:
    msg = "Please open an issue at https://github.com/google/jax/issues !"
    raise NotImplementedError(msg)

  fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep)
  _, in_rep_ = split_list(in_rep, [num_consts])
  out_rep2 = []

  @pe._memoize
  def fwd_jaxpr_thunk_(*zeros):
    fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros))
    fwd_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fwd_jaxpr, in_rep_)
    out_rep2.append(out_rep)
    return fwd_jaxpr_.jaxpr, fwd_jaxpr_.consts

  bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_)

  outs = custom_derivatives.custom_vjp_call_jaxpr_p.bind(
      *args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_,
      num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
  out_rep = out_rep2[0] if out_rep2 else out_rep
  return outs, out_rep

@register_check(custom_derivatives.custom_vjp_call_jaxpr_p)
def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_):
  return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep)

del _check_rules[lax.tie_p]

@register_check(lax.tie_p)
def _tie_check(mesh, x_rep, y_rep):
  return x_rep
register_norewrite(lax.tie_p)

# Batching

@@ -853,12 +1156,13 @@ def _shard_map_batch(
    in_names: tuple[AxisNames, ...],
    out_names_thunk: Callable[[], tuple[AxisNames, ...]],
    check_rep: bool,
    rewrite: bool,
    auto: frozenset) -> Sequence[batching.BatchTracer]:
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers)
  if all(bdim is batching.not_mapped for bdim in in_dims):
    return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
                     out_names_thunk=out_names_thunk, check_rep=check_rep,
-                    auto=auto)
                     rewrite=rewrite, auto=auto)
  if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
    raise NotImplementedError
  fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
@@ -874,7 +1178,7 @@ def _shard_map_batch(

  new_params = dict(mesh=mesh, in_names=new_in_names,
                    out_names_thunk=new_out_names_thunk, check_rep=check_rep,
-                   auto=auto)
                    rewrite=rewrite, auto=auto)
  out_vals = prim.bind(fun, *in_vals, **new_params)
  make_tracer = partial(batching.BatchTracer, trace,
                        source_info=source_info_util.current())
@@ -882,8 +1186,8 @@ def _shard_map_batch(
batching.BatchTrace.process_shard_map = _shard_map_batch

def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names,
-                                 out_names_thunk, check_rep, auto):
-  del mesh, in_names, out_names_thunk, check_rep, auto
                                  out_names_thunk, check_rep, rewrite, auto):
  del mesh, in_names, out_names_thunk, check_rep, rewrite, auto
  vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
                            for t in out_tracers)
  m = trace.main
@@ -906,7 +1210,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names):

# Autodiff

def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
-                  out_names_thunk, check_rep, auto):
                   out_names_thunk, check_rep, rewrite, auto):
  primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
  which_nz = [type(t) is not ad.Zero for t in tangents]
  tangents = [t if type(t) is not ad.Zero else None for t in tangents]
@@ -921,7 +1225,7 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
    return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
  params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names),
                out_names_thunk=new_out_names_thunk, check_rep=check_rep,
-               auto=auto)
                rewrite=rewrite, auto=auto)
  f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
  result = shard_map_p.bind(f_jvp, *args, **params)
  primal_out, tangent_out = tree_unflatten(out_tree(), result)
@@ -931,8 +1235,8 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
ad.JVPTrace.process_shard_map = _shard_map_jvp

def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
-                               out_names_thunk, check_rep, auto):
-  del mesh, in_names, out_names_thunk, check_rep, auto
                                out_names_thunk, check_rep, rewrite, auto):
  del mesh, in_names, out_names_thunk, check_rep, rewrite, auto
  primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
  out, treedef = tree_flatten((primals, tangents))
  tangents_nz = [type(t) is not ad.Zero for t in tangents]
@@ -946,7 +1250,7 @@ def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process

def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
-                           out_names_thunk, check_rep, auto):
                            out_names_thunk, check_rep, rewrite, auto):
  in_pvals = [t.pval for t in tracers]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
  unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
@@ -966,7 +1270,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,

  known_params = dict(mesh=mesh, in_names=(*known_in_names,),
                      out_names_thunk=known_out_names, check_rep=check_rep,
-                     auto=auto)
                      rewrite=rewrite, auto=auto)
  out = shard_map_p.bind(f_known, *in_consts, **known_params)
  out_knowns, out_avals_sharded, jaxpr, env = aux()
  out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
@@ -980,7 +1284,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
  unk_arg_tracers = [t for t in tracers if not t.is_known()]
  unk_params = dict(mesh=mesh, in_names=unk_in_names,
                    out_names=unk_out_names, jaxpr=jaxpr, check_rep=False,
-                   auto=auto)
                    rewrite=rewrite, auto=auto)
  out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in out_avals]
@@ -992,7 +1296,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval

def _shard_map_partial_eval_post_process(
-   trace, tracers, mesh, in_names, out_names_thunk, check_rep, auto):
    trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto):
  del check_rep
  unk_tracers = [t for t in tracers if not t.is_known()]
  jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
@@ -1013,7 +1317,7 @@ def _shard_map_partial_eval_post_process(
  staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
  staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
                       out_names=(*out_names_unknown,), check_rep=False,
-                      auto=auto)
                       rewrite=rewrite, auto=auto)

  out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
@@ -1053,10 +1357,11 @@ def _promote_scalar_residuals_jaxpr(jaxpr, res):
  return jaxpr, res

def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
-                        check_rep, auto):
                         check_rep, rewrite, auto):
  mb_div = lambda x, y: x / y if y != 1 else x
  out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
-            else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
             else x if rewrite
             else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
             for ns, x in zip(out_names, out_cts)]
  args = [x if type(x) is not ad.UndefinedPrimal else
          ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@@ -1072,9 +1377,10 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
    out = ad.backward_pass(
        jaxpr_unknown.jaxpr, (), False, (), (*res_reshaped, *undefs), out_cts
    )
-   return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
-           else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
    out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
           else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
           for ns, x in zip(in_names, out)]
    return out

  fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
  fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)
@@ -1088,7 +1394,8 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,

  out_flat = shard_map_p.bind(
      fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
-     out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto)
      out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
      auto=auto)
  return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose

@@ -1201,7 +1508,6 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
  return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce


# Implementing pmap in terms of shard_map

def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
@@ -1277,3 +1583,148 @@ def _get_devices(p, backend):
  if jax.process_count() > 1:
    return devs[:p.global_axis_size]
  return devs[:p.local_axis_size]


### Rewrite!

class RewriteTracer(core.Tracer):
  rep: set[AxisName]
  val: Val

  def __init__(self, trace, rep, val):
    self._trace = trace
    self.rep = rep
    self.val = val

  @property
  def aval(self) -> core.AbstractValue:
    return core.get_aval(self.val)

  def full_lower(self) -> RewriteTracer:
    return self

  def __str__(self) -> str:
    return str(self.val)  # TODO(mattjj): could show replication info here

class RewriteTrace(core.Trace):
  mesh: Mesh
  dyna: int

  def __init__(self, *args, mesh, dyna):
    super().__init__(*args)
    self.mesh = mesh
    self.dyna = dyna

  def pure(self, val) -> RewriteTracer:
    return RewriteTracer(self, set(self.mesh.axis_names), val)

  def lift(self, tracer: core.Tracer) -> RewriteTracer:
    return RewriteTracer(self, set(self.mesh.axis_names), tracer)

  def sublift(self, tracer: core.Tracer) -> RewriteTracer:
    return RewriteTracer(self, tracer.rep, tracer.val)

  def process_primitive(self, prim, in_tracers, params):
    rule = _rewrite_rules.get(prim, partial(_rule_missing, prim))
    in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers)
    with core.new_dynamic(self.dyna):
      out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
    out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals)
    return out_tracers if prim.multiple_results else out_tracers[0]

  def process_call(self, call_primitive, f, in_tracers, params):
    in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers)
    f, out_reps = _rewrite_subtrace(f, self.main, tuple(in_reps))
    with core.new_dynamic(self.dyna):
      out_vals = call_primitive.bind(f, *in_vals, **params)
    return map(partial(RewriteTracer, self), out_reps(), out_vals)

  def post_process_call(self, call_primitive, out_tracers, params):
    assert False  # unreachable

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
    if symbolic_zeros:
      msg = "Please open an issue at https://github.com/google/jax/issues !"
      raise NotImplementedError(msg)
    in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
    fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
    jvp, out_reps2 = _rewrite_subtrace(jvp, self.main, in_reps * 2)
    with core.new_dynamic(self.dyna):
      out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros)
    fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2)
    if not fst:
      assert out_reps == out_reps[:len(out_reps) // 2] * 2
      out_reps = out_reps[:len(out_reps) // 2]
    return map(partial(RewriteTracer, self), out_reps, out_vals)

  def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
    assert False  # unreachable

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
                              symbolic_zeros):
    if symbolic_zeros:
      msg = "Please open an issue at https://github.com/google/jax/issues !"
      raise NotImplementedError(msg)
    in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
    fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
    fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]]
    fwd, out_reps2 = _rewrite_subtrace(fwd, self.main, fwd_in_reps)
    bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps)
    with core.new_dynamic(self.dyna):
      out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees,
                           symbolic_zeros=symbolic_zeros)
    fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2)
    if not fst:
      _, res_tree = out_trees()
      _, out_reps = split_list(out_reps, [res_tree.num_leaves])
    return map(partial(RewriteTracer, self), out_reps, out_vals)

  def post_process_custom_vjp_call(self, out_tracers, _):
    assert False  # unreachable

  # TODO process_axis_index

@lu.transformation
def _efficient_transpose_rewrite(mesh, in_names, out_names_thunk, *args):
  in_reps = map(partial(_in_names_to_rep, mesh), in_names)
  lvl = core.dynamic_level()
  with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
    t = main.with_cur_sublevel()
    in_tracers = map(partial(RewriteTracer, t), in_reps, args)
    ans = yield in_tracers, {}
    out_tracers = map(t.full_raise, ans)
    out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
    del main, t, in_tracers, out_tracers, ans
  out_rep_dst = [frozenset(_unmentioned(mesh, n)) for n in out_names_thunk()]
  out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
              else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)]
  yield out_vals

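# A hedged example of the final matching step above: an output whose body rep
# is {'x', 'y'} headed for out_specs that mention only 'y' (so out_rep_dst ==
# {'x'}) gets a pbroadcast over ('y',) to make the unmentioned axes line up.
assert {'x', 'y'} - {'x'} == {'y'}  # axes to pbroadcast over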
@lu.transformation_with_aux
def _rewrite_subtrace(main, in_reps, *in_vals):
  assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
  t = main.with_cur_sublevel()
  in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals)
  with core.new_dynamic(main.level):
    outs = yield in_tracers, {}
  out_tracers = map(t.full_raise, outs)
  out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
  yield out_vals, out_reps

def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
  def new_bwd(*args):
    lvl = core.dynamic_level()
    with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
      bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), main, in_reps())
      out = bwd_.call_wrapped(*args)
      del main
    return map(_match_replication, reps_thunk(), reps_dst, out)
  return new_bwd

def _match_replication(src, dst, x):
  if dst - src:
    x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src),
                      axis_index_groups=None)
  if src - dst:
    x = pbroadcast(x, tuple(n for n in src if n not in dst))
  return x

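# Hedged examples of the two branches above, as plain sets: a strictly larger
# dst rep calls for a psum2 over the missing axes, and a strictly larger src
# rep for a pbroadcast over the extra ones.
assert {'x', 'y'} - {'x'} == {'y'}  # src={'x'}, dst={'x','y'}: psum2 over ('y',)
assert {'x', 'y'} - {'y'} == {'x'}  # src={'x','y'}, dst={'y'}: pbroadcast over ('x',)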
@@ -37,6 +37,7 @@ from jax._src import xla_bridge
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp
@@ -113,9 +114,12 @@ class ShardMapTest(jtu.JaxTestCase):
    mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
    assert a.device_buffers[0].shape == (4, 2)

    # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use the
    # all_gather_invariant primitive, which differs in its output replication
    # type compared to all_gather.
    @jax.jit
    @partial(shard_map, mesh=mesh,
-            in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y')))
             in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
    def fwd(a):
      return lax.all_gather(a, 'z', axis=0, tiled=True)
@@ -559,8 +563,7 @@ class ShardMapTest(jtu.JaxTestCase):
  def test_partial_eval_custom_axis_env(self):
    mesh = Mesh(jax.devices(), ('i',))

-   @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'),
-            check_rep=False)  # check_rep=False b/c no scan rep rule yet
    @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
    def f(_):
      _, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')),
                            None, None, length=1)
@@ -861,6 +864,274 @@ class ShardMapTest(jtu.JaxTestCase):
    self.assertIn("\"[('i',)]\"", mhlo_str)
    self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)

  def test_rewrite_process_call(self):
    def f(x):
      return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)
    y = jax.jit(g)(x)  # eager requires shmap to have ShardMapTrace.process_call
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

  def test_rewrite_post_process_call(self):
    # We shouldn't hit post_process_call here because of RewriteTrace's dynamic
    # behavior (i.e. no data dependence).
    mesh = jtu.create_global_mesh((4,), ('x',))

    @jax.jit
    @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
    def f(x):
      return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x

    x = jnp.arange(4.)
    y = f(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

  def test_rewrite_process_custom_jvp_call(self):
    @jax.custom_jvp
    def foo(x):
      return 2. * x

    @foo.defjvp
    def foo_jvp(primals, tangents):
      (x,), (x_dot,) = primals, tangents
      return foo(x), 2. * x_dot

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)

    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,))
    self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
    self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)

  def test_rewrite_process_custom_vjp_call(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 2. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))

    x = jnp.arange(4.)
    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

  def test_rewrite_process_custom_vjp_call_match_more_replicated(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return jnp.ones_like(y_bar),  # diff! more replicated than primal/tangent

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))
    x = jnp.arange(4.)

    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)

  def test_rewrite_process_custom_vjp_call_match_less_replicated(self):
    @jax.custom_vjp
    def foo(x, y):
      del y
      return 2. * x

    def foo_fwd(x, y):
      return foo(x, y), y

    def foo_bwd(y, _):
      return y, None  # diff! x_bar less replicated than primal/tangent

    foo.defvjp(foo_fwd, foo_bwd)

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x, y: foo(x, y) * y, mesh,
                  in_specs=(P(), P('x')), out_specs=P('x'))
    x = jnp.arange(4.)
    y = jnp.arange(4 * 4.)

    z = jax.jit(g)(x, y)
    self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)

    z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y)
    self.assertAllClose(z.sum(), z_, check_dtypes=False)
    self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
                        check_dtypes=False)

  def test_rewrite_custom_vjp_call_jaxpr(self):
    @jax.custom_vjp
    def foo(x):
      return 2. * x

    def foo_fwd(x):
      return foo(x), None

    def foo_bwd(_, y_bar):
      return 2. * y_bar,

    foo.defvjp(foo_fwd, foo_bwd)

    def foo_scan(x):
      y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1)
      return y

    mesh = jtu.create_global_mesh((4,), ('x',))
    g = shard_map(lambda x: foo_scan(x) * x, mesh,
                  in_specs=(P('x'),), out_specs=P('x'))

    x = jnp.arange(4.)
    y = jax.jit(g)(x)
    self.assertAllClose(y, 2 * x * x, check_dtypes=True)

    y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
    self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
    self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)

  def test_transpose_identity(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def f(x):
      return x

    jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.)
    e, = jaxpr.jaxpr.eqns
    self.assertEmpty(e.params['jaxpr'].eqns)

    jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,))
    e, = jaxpr.jaxpr.eqns
    self.assertEmpty(e.params['jaxpr'].eqns)

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
    def g(x):
      return jax.jit(lambda x: x)(x)

    jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEmpty(e1.outvars)
    self.assertEmpty(e2.params['jaxpr'].eqns)

  def test_fanout_specs_transpose_to_psum(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x'))
    def f(x):
      return x

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e2, = e.params['jaxpr'].eqns
    self.assertEqual(str(e2.primitive), 'psum2')
    self.assertEqual(e2.params['axes'], ('x',))

  def test_fanin_psum_transposes_to_fanout(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P())
    def f(x):
      return jax.lax.psum(x, 'x')

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.]))
    e, = jaxpr.jaxpr.eqns
    e1, = e.params['jaxpr'].eqns
    self.assertEqual(str(e1.primitive), 'pbroadcast')

  def test_psum_with_implicit_fanout_self_transposes(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      return jax.lax.psum(x, 'x')

    jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEqual(str(e1.primitive), 'psum2')
    self.assertEqual(str(e2.primitive), 'pbroadcast')

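  # A standalone repro of the decomposition the test above asserts on (a
  # hedged sketch, assuming at least four devices are available):
  #
  #   from functools import partial
  #   import jax, jax.numpy as jnp
  #   from jax.sharding import Mesh, PartitionSpec as P
  #   from jax.experimental.shard_map import shard_map
  #
  #   mesh = Mesh(jax.devices()[:4], ('x',))
  #   f = shard_map(lambda x: jax.lax.psum(x, 'x'), mesh,
  #                 in_specs=P('x'), out_specs=P('x'))
  #   print(jax.make_jaxpr(f)(jnp.arange(4.)))  # shows psum2 then pbroadcast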
  def test_rewrite_binops(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x'))
    def f(x, y):
      return x * y

    jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e = e.params['jaxpr'].eqns[0]
    self.assertEqual(e.primitive.name, 'pbroadcast')
    self.assertEqual(e.params['axes'], ('x',))

  def test_rewrite_scan(self):
    mesh = jtu.create_global_mesh((4,), ('x',))

    @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
    def f(x):
      x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None,
                          length=2)
      return x

    jaxpr = jax.make_jaxpr(f)(jnp.arange(4.))
    e, = jaxpr.jaxpr.eqns
    e, = e.params['jaxpr'].eqns
    e1, e2 = e.params['jaxpr'].eqns
    self.assertEqual(e1.primitive.name, 'psum2')
    self.assertEqual(e2.primitive.name, 'pbroadcast')

  def test_check_rep_false_grads(self):
    # This test is redundant with the systematic tests below, but it serves as
    # a direct regression test for a bug.
    mesh = jtu.create_global_mesh((4,), ('heads',))

    def f(q, k, v):

      def body(q, k, v):
        return q * k[None, :] + v[None, :]

      out = shard_map(body, mesh, check_rep=False,
                      in_specs=(q_spec, kv_spec, kv_spec,),
                      out_specs=q_spec)(q, k, v)
      return out.sum()

    q_spec = P('heads', None)
    kv_spec = P(None)
    q = jax.device_put(jnp.arange(32.).reshape(4, 8),
                       jax.sharding.NamedSharding(mesh, q_spec))
    k = jax.device_put(jnp.arange(8.),
                       jax.sharding.NamedSharding(mesh, kv_spec))
    v = jax.device_put(jnp.arange(8.),
                       jax.sharding.NamedSharding(mesh, kv_spec))

    jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)


class FunSpec(NamedTuple):
  name: str
@@ -1126,12 +1397,15 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
    self.assertAllClose(expected, out, check_dtypes=False)

  @parameterized.named_parameters(
-     sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
      (name + f'_check_rep={check_rep}', *params, check_rep)
      for (name, *params) in sample(config.FLAGS.jax_num_generated_cases,
                                    sample_shmap)
      for check_rep in [True, False]
  )
  @jax.default_matmul_precision("float32")
- def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _):
  def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
    mesh = self.make_mesh(mesh)
    args = map(jnp.array, args)
-   f = shard_map(fun, mesh, in_specs, out_specs)
    f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep)
    if jit:
      f = jax.jit(f)
    jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)