rolling forward shard_map transpose fixes

The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, worked fine. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice the breakage in the tests!

The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those!

The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule.
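
For illustration only (a hypothetical repro, not part of the change): the broken path is a check_rep=False shard_map with an unmapped input (fan-out) and an unmapped output (fan-in-consensus), where the transpose rule itself has to supply the psum and the division:

```python
# Hypothetical sketch: assumes a 1-D mesh over axis 'i' and suitable w, x, ybar.
f = shard_map(lambda w, x: jax.lax.psum(w * x, 'i'), mesh,
              in_specs=(P(), P('i')), out_specs=P(), check_rep=False)
y, f_vjp = jax.vjp(f, w, x)
wbar, xbar = f_vjp(ybar)  # wbar needs a psum, and ybar a division, from the transpose rule
```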

Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa

PiperOrigin-RevId: 561805840
Matthew Johnson 2023-08-31 17:30:34 -07:00 committed by jax authors
parent c38f67043c
commit 70b58bbd30
10 changed files with 1399 additions and 126 deletions

View File

@ -0,0 +1,515 @@
# Efficient transposition of replication-inducing collectives
*mattjj@*, *dougalm@*
*August 2023*
## Motivation
We have an efficiency problem in automatically transposing `shmap`s containing
certain collectives. The issue arises with `psum` and `all_gather`, specifically
when the output of the collective is returned to the caller as an unmapped
output. And it's not an edge case: for example, it arises when applying `grad`
to a `shmap`-based batch data parallel neural network loss function which uses
`psum` to compute the total loss.
We've known about this problem for some time. An analogous issue exists with
`pmap`, though it's been worked around by keeping `grad` inside `pmap` rather than
outside. A primary goal of the incomplete avals-with-names work was to address a
version of this transpose efficiency problem. This doc draws on those ideas,
while extending and revising them to handle more cases and to be much easier to
land. Indeed the solution proposed here only affects the `shmap` implementation.
The rest of the system need not be changed (yet).
The main purpose of this doc is to define this transpose efficiency problem and
propose an easy-to-land solution.
This doc is not about:
* logical axis names on arrays (the only axis names here are just like in
`shmap` and OG `pmap`);
* changing autodiff semantics (all the numbers and (non)errors are staying the
same, we're just making things more efficient);
* allowing user code to reflect on any new information, or really affecting user
code at all.
## Problem: efficient transpose of `psum` or `all_gather` depends on whether cotangents are invariant across devices
Consider this semi-realistic example, meant to resemble a replicated-parameter
batch data parallel loss function:
```python
devices = jax.devices() # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
in_specs=(P(None, None), P('batch', None)),
out_specs=P())
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
return global_loss
```
Notice the `out_specs=P()`, which indicates an unmapped output. If you're not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.
Most of the details in the `loss` example aren't important. All that matters for
our purposes is that we're applying `psum` (or rather `pmean = lambda x, name:
psum(x, name) / psum(1, name)`) at the end. So a distilled version looks like
this:
```python
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
```
We even simplified notation by suppressing the `mesh` argument. In the examples to
follow it can be inferred from context.
What does the transpose look like? Writing `t` to mean function transpose, we
could evaluate `t(f1)(ybar)` for any `ybar` efficiently by applying the function
`¿f1_transpose?` below:
```python
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
```
But that's not the transpose we currently get as t(f1).
Instead, the current recipe for transposition is roughly that we switch
`in_specs` and `out_specs`, do some division rescaling for unmapped outputs, and
transpose the body. Because `psum` is its own transpose (as an all-reduce sum),
we end up producing this transpose:
```python
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
```
This transpose gets the numbers right, but it's wasteful. We know statically
from the transpose's `in_specs=P()` that `ybar` has the same value for each function
instance, i.e. that its value is device-invariant for devices along the mesh
axis named `i`, and yet we apply a `psum` to it! That uses expensive communication
just to multiply the value on each device by 8. (Here 8 refers to the size of
axis i. The division by 8 comes from the original function's `out_specs=P()`; it
and the trivial `psum` basically cancel each other out.)
What are we doing wrong? We're not exploiting the fact that cotangents `ybar`
corresponding to `f1`'s unmapped outputs are guaranteed to be device-invariant;
instead, we're defensively `psum`ming them as if they weren't because `psum`'s
transpose can't be sure given the local information it has. Sometimes the `psum`
is necessary, as in transposing `f2` with respect to its first argument:
```python
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
```
Intuitively, if our transpose machinery could tell the difference between
Example 1 and Example 2, we could do better by avoiding the psum and division
where possible.
The inefficient examples can be even smaller. Consider transposing this cursed
identity function:
```python
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i'), P(), P())
...
```
It keeps getting bigger the more we transpose. How embarrassing!
And `psum` isn't the only culprit. Something analogous holds true for
`all_gather`:
```python
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
```
This program is a bit artificial. Why do an `all_gather` and feed the result into
an unmapped output, rather than skipping the `all_gather` in the body and just
using `out_specs=P('i')` to collect the results? But even though it's cooked-up,
this example nevertheless exhibits a transpose which unnecessarily performs
communication (we could have just performed a non-communicating slice),
analogous to Example 1 for `psum`.
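To make "non-communicating slice" concrete, here's the transpose we'd prefer, in the same wishful notation as `¿f1_transpose?` above (it foreshadows the `pscatter` lowering described later; assume each instance's operand has shape `f32[1]`):
```python
# The efficient "transpose" we'd like for Example 4: no communication, each
# function instance just indexes its own block of the device-invariant ybar.
¿f4_transpose? = shmap(lambda ybar: ybar[axis_index('i'), None],
                       in_specs=P(), out_specs=P('i'))
```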
Also analogously to the `psum` examples, the defensive `psum_scatter` is
necessary in some cases:
```python
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
```
So how do we avoid these inefficient transposes?
## Solutions
Here are two solution ideas. They aren't mutually exclusive. But (spoilers) the
second one is better, and it's all we need.
### Partial solution "P-sum": build the ability to express a `psum` into `out_specs`
This solution is a bit of a strawperson because it would offer only an awkward
way to write programs. And it wouldn't even fix everything! But it's worth
considering, if only to motivate a more complete solution.
Example 4 above is artificial because we could have just used `out_specs` instead
of an `all_gather` in the body:
```python
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
```
The `f4_better` version doesn't have any transposition problems, since the
transpose problems arise from collectives in the body.
Analogously, we could fix Example 1 by extending `out_specs` so that they can
express summing:
```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
```
So offering `psum`s built into `out_specs` fixes the transpose problem of
Example 1. But it doesn't fully fix the cursed identity transpose in Example 3:
```python
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
```
It's an improvement since the program doesn't continue to get bigger as we keep
transposing, but we're still doing wasteful communication.
### Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives
This solution has two components:
1. track when values are guaranteed to be device-invariant vs device-varying
over particular mesh axes, and
2. decompose `psum` into a two-step process, introducing a new `pbroadcast`
primitive, and introduce new primitives for `all_gather` and its transposes.
Morally, the tracking of device-invariant vs device-varying information is a
type-level consideration. But for the expedience of our first implementation, we
don't need to literally add the information to abstract values or jaxpr types.
Before we get to implementation, we'll first introduce the idea using types.
Also to follow is a discussion of making the user API convenient and backward
compatible. But to first introduce the idea, we'll ignore convenience and
instead write code that is as explicit as possible.
#### Tracking device invariance in avals (a.k.a. avals-with-names, revived)
We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a `shmap` are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We'll call such values device-invariant. For values that are not
device-invariant, we'll say they're device-varying, though really we mean
potentially device-varying from the point of view of the type system.
To encode device variance in types, we'll extend the syntax of types for arrays.
We'll write things like `x:f32[3,4]{i}` to indicate that `x` is (potentially)
device-varying along mesh axis `i` (and device-invariant over any other mesh
axes of the `shmap`). More generally, we'll say the grammar for array type
syntax is something like
```
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
```
We'll also update the typing rules to handle device variance types:
* for first-order primitives other than collectives
- for multi-arity primitives, the operand device variance types must be equal
where shapes must be equal, e.g. `mul x:f32[s1]{r1} y:f32[s2]{r2}` requires
`r1 == r2` in addition to `s1 == s2`
- the output device variance type must be the same as the operand(s)
* for higher-order primitives
- we just instantiate any type variables including the device variance type
(and checking types for equality checks their device variance types are
equal)
- (when performing type inference, e.g. for branches of a `cond`, we take the
union of the sets of axis names in device variance types)
* for first-order collectives
- a collective can either accept a device-varying or device-invariant input
(along a mesh axis corresponding to its axis name parameter); it's an error
to pass a device-invariant operand to a collective which accepts
device-varying operands and vice-versa
- a collective can either produce a device-varying or device-invariant output
- see the table below
As a side benefit, whatever logic implements this type checking can subsume
`shmap`'s "static analysis" check for whether a `shmap` body function is
compatible with any unmapped `out_specs`.
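As a small worked example of the multi-arity rule (a sketch in the annotated pseudo-Python used later in this doc; `pbroadcast` is defined in the table just below):
```python
# Sketch: adding a device-varying value and a device-invariant value is a
# type error under these rules until the invariant operand is lifted.
def body(x: f32[3]{i}, y: f32[3]):       # x varying over 'i', y invariant
  z:f32[3]{i} = x + pbroadcast(y, 'i')   # operand variance types now agree
  return z
```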
Here's a table summarizing the device variance typing for collective primitives:
| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| `psum2` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pbroadcast` |
| `pbroadcast` | `Invariant -> Varying` | `y:f32[3]{i} = pbroadcast(x:f32[3], 'i')` | no-op (no communication) | `psum` |
| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` | `AllToAll` (communication) | `all_to_all` |
| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |
| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |
| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |
| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |
| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |
There are some surprising things here!
* We introduced several new primitives, including
- `pbroadcast`, which interestingly lowers to a no-op
- `all_gather_invariant`, which lowers to the same thing as `all_gather` but
has a different device variance type (essentially `all_gather` has a
`pbroadcast` fused into it, whereas `all_gather_invariant` does not)
- `pscatter` which is the dual (transpose) of `all_gather_invariant`
* `all_gather` has a device-varying result
Intuitively, the reason to introduce `pbroadcast` (other than to make the typing
rules work) is so that `psum` can transpose to a physical no-op. The reason we
need `all_gather` to have a device-varying result is so that we can transpose it
to `psum_scatter`; if we instead left it with a device-invariant result, we
might need a downstream `pbroadcast`, and that composition would transpose to an
inefficient `psum` followed by slicing / `pscatter`. So instead we have a
`pbroadcast` "fused into" the `all_gather`, thus allowing for an efficient
transpose into `psum_scatter`. We provide `all_gather_invariant` and its
transpose `pscatter` mainly for completeness; it's unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using `out_specs`).
Interestingly, the `psum` and `pbroadcast` transpose pair corresponds to the
`psum_idrev` and `id_psumrev` that users introduced while training LLMs with
`pmap`.
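In implementation terms this pairing is just two linear primitives registered as each other's transposes; here's a condensed sketch of what the `shard_map.py` changes in this commit do (using JAX-internal `ad.deflinear2`):
```python
# Condensed from this change: psum2 and pbroadcast are linear, and each
# transposes to the other (so transposing twice is the identity).
ad.deflinear2(psum2_p, lambda cts, *_, axes, axis_index_groups:
    pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups))
ad.deflinear2(pbroadcast_p, lambda cts, *_, axes, axis_index_groups:
    psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups))
```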
#### How this system solves the inefficient transpose examples
Consider again the simplified motivating example:
```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
w:f32[]{i} = g(x)
y:f32[]{} = psum(w, 'i')
return y
```
With these new rules, the transpose is:
```python
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
wbar:f32[]{i} = pbroadcast(ybar, 'i')
xbar:f32[3,4]{i} = transpose(g)(wbar)
return xbar
```
where evaluating the `pbroadcast` application involves no communication or FLOPs
at all; it's a no-op. Notice that if we keep transposing, the body does not grow
in size; indeed `t(t(f1)) == f1`. Efficiency achieved!
And we wouldn't mess up the other examples either, so long as we `pbroadcast` to
make the types check where needed:
```python
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
```
Intuitively, in Example 1 we now only have "half the original psum", whereas in
Example 2 we get both "halves". For Example 3 we never need any operations in
the body at all.
For the `all_gather` examples, Example 4 would need to use
`all_gather_invariant` to have an efficient transpose (though it'd be better to
use `out_specs` instead of the collective in the body):
```python
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
y:f32[8]{} = all_gather_invariant(x, 'i')
return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
xbar:f32[1]{i} = pscatter(ybar, 'i')
return xbar
```
For Example 5, using the device-varying `all_gather` works as we'd want:
```python
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
z:f32[8]{i} = all_gather(x, 'i')
w:f32[8]{i} = z * y
return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
zbar:f32[8]{i} = wbar * y
xbar:f32[1]{i} = psum_scatter(zbar, 'i')
return xbar
```
### How to make the API convenient for users (and backward compatible)
But what user wants to write `pbroadcast`s? And what developer wants to break
lots of existing user code involving `psum`s which are not fed into unmapped
outputs? Not me!
Instead we can automatically insert the `pbroadcast`s. It's a bit analogous to how
we do automatic rank promotion at the `jax.numpy` layer, inserting broadcasts to
avoid rank mismatch errors in binary operators. But it's much simpler since we
don't need to contend with shape tuples. The typical rule is: whenever we see a
multi-arity operation where the operands disagree in their device variance
types, take the union of operands' device variance types' axis name sets and
insert `pbroadcast`s to lift each operand to the resulting device variance type.
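Here's a minimal sketch of that rule as plain Python over replication sets (a `rep` set records the mesh axes over which a value is known to be replicated/invariant, matching the machinery in the implementation below):
```python
# Which axes must each operand be pbroadcast over, and what is the result's
# replication set? (Mirrors the standard rewrite rule added in this change.)
def pbroadcast_axes(in_reps: list[set]) -> tuple[list[tuple], set]:
  out_rep = set.intersection(*in_reps) if in_reps else set()
  return [tuple(sorted(src - out_rep)) for src in in_reps], out_rep

# e.g. multiplying a value replicated over 'i' by one varying over 'i':
# pbroadcast_axes([{'i'}, set()]) == ([('i',), ()], set())
```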
Automatically inserting `pbroadcast`s just before they're needed may mean we apply
the same `pbroadcast` to the same operand multiple times, creating common
subexpressions. When we transpose, those could turn into a sum-of-`psum`s rather
than a `psum`-of-sum. We'll rely on the compiler to clean that up as appropriate.
If it's a problem then we could add some simple memoization to the
`pbroadcast`-insertion pass.
The user API for `all_gather` will mean `all_gather_p` by default (not
`all_gather_invariant_p`), covering the common case and meaning no `pbroadcast`s
must be inserted.
We can provide an option on `shmap` to disable this automatic insertion of
`pbroadcast`s, in which case it'll be up to the user to ensure type-correctness.
This explicit option may be appealing to some who want to be explicit about
where the `psum`s occur in the backward pass.
### How to implement the solution
The key to making the implementation lightweight is that **we aren't going to
add these types to avals or jaxprs**. At least, not at first. That can be
expensive because it requires updating the rest of JAX, e.g. all consumers of
avals and jaxprs may need to handle the new types. We're not falling for that
again!
Instead we're going to keep these extended types as metadata internal to
`shmap`, just like the current "replication checking for `out_specs`" machinery
is internal to `shmap`. Indeed this solution amounts to a relatively small
extension to that existing machinery: it was already tracking the same
information; now we're just adding the `pbroadcast`s.
We have at least two options for where to perform the `pbroadcast` insertion:
1. just before transposition, in the transpose rule, where we have a jaxpr of
the computation to be transposed;
2. in every `shmap` body, whether eagerly executed or staged out, like the
current "replication checking for `out_specs`" machinery.
The former may end up being easier since we only have to handle the jaxpr case,
and only linear primitives. But we'll start by trying the latter so the
implementation here is a strict revision/extension to the existing
replication-checking logic.
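Concretely, in this change the rewrite is applied as a transformation on the `shmap` body exactly when `check_rep=True`, and a new `rewrite` parameter on the primitive records whether that happened (condensed from the `shard_map.py` diff below):
```python
# Condensed from this change's shard_map.py: rewrite the body iff check_rep
# is set, and thread that fact through the primitive's params.
if rewrite := check_rep:
  fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk)
out_flat = shard_map_p.bind(
    fun, *args_flat, mesh=mesh, in_names=in_names_flat,
    out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite,
    auto=auto)
```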
## Appendix: defining and motivating maps with unmapped inputs and outputs
For concreteness, we'll mostly focus on `shmap`, though these same ideas apply
to e.g. `pmap` and probably `xmap`.
An argument/input is _unmapped_ along a mesh axis when the corresponding entry
of `in_specs` doesn't mention that mesh axis's name. Logically it means that
each function instance along that mesh axis gets the same value for the
argument. To the caller, each operand is sliced according to the mesh axes over
which the operand is mapped, whereas there is no slicing for mesh axes over
which the operand is unmapped.
An output is _unmapped_ along a mesh axis when the corresponding entry of
`out_specs` doesn't mention that mesh axis's name. Logically it means each
function instance along that mesh axis must return the same value. To the
caller, each result of the `shmap` is formed by concatenating the function
instances' return values along the mesh axes over which the output is mapped,
whereas for mesh axes over which the output is unmapped only one copy of the
value is used.
See [the `shmap`
JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples
of unmapped inputs and outputs. For comparison, in `vmap` unmapped
inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather
than an `int`).
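For a quick side-by-side (illustrative only, assuming a two-argument function `f`, a mesh with axis `'i'`, and compatible shapes):
```python
# vmap: in_axes=None marks the unmapped argument, an int marks a mapped one
jax.vmap(f, in_axes=(None, 0), out_axes=0)
# shmap: leaving the mesh axis name out of the spec plays the same role
shmap(f, mesh, in_specs=(P(), P('i')), out_specs=P('i'))
```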
Here are reasons we like unmapped inputs and outputs for `shmap`:
* **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape
hatch should be able to do too. Or else we'd have a lacking escape hatch! If
we didn't have unmapped outputs in `shmap`, then we couldn't express the same
batch-parallel loss function computations as `pjit`.
* **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped
inputs, and...
* **Closure under transposition.** Once we have unmapped inputs, it's natural to
be able to transpose to unmapped outputs.
So unmapped outputs are both canonical and useful!

View File

@ -48,6 +48,7 @@ Then create a pull request that adds a file named
12049: Type Annotation Roadmap for JAX <12049-type-annotations>
14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map>
15856: `jax.extend`, an extensions module <15856-jex>
17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose>
Several early JEPs were converted in hindsight from other documentation,

View File

@ -1084,8 +1084,7 @@ def _why_alive_container_info(container, obj_id) -> str:
@contextmanager
def new_main(trace_type: type[Trace],
dynamic: bool = False,
def new_main(trace_type: type[Trace], dynamic: bool = False,
**payload) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
@ -1111,6 +1110,20 @@ def new_main(trace_type: type[Trace],
leaked_tracers = maybe_find_leaked_tracers(t())
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
@contextmanager
def new_dynamic(level: int) -> Generator[None, None, None]:
stack = thread_local_state.trace_state.trace_stack
prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level]
_update_thread_local_jit_state(stack.dynamic)
try:
yield
finally:
stack.dynamic = prev_dynamic
_update_thread_local_jit_state(stack.dynamic)
def dynamic_level() -> int:
return thread_local_state.trace_state.trace_stack.dynamic.level
@contextmanager
def new_base_main(trace_type: type[Trace],
**payload) -> Generator[MainTrace, None, None]:

View File

@ -853,6 +853,7 @@ def _custom_vjp_call_jaxpr_jvp(
tangents_out = ad.custom_lin_p.bind(
*res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(ad.recast_to_float0, primals_out, tangents_out)
return primals_out, tangents_out
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

View File

@ -403,6 +403,7 @@ class JVPTrace(Trace):
tangents_out = custom_lin_p.bind(
*res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
out_avals=avals_out, symbolic_zeros=symbolic_zeros)
tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out)
tangents_out = map(recast_to_float0, primals_out, tangents_out)
return map(partial(JVPTracer, self), primals_out, tangents_out)

View File

@ -4966,6 +4966,17 @@ def _empty_lower(ctx, *, dtype):
mlir.register_lowering(empty_p, _empty_lower)
tie_p = core.Primitive('tie')
tie_p.def_impl(lambda x, y: y)
tie_p.def_abstract_eval(lambda x, y: y)
mlir.register_lowering(tie_p, lambda ctx, x, y: [y])
ad.primitive_jvps[tie_p] = \
lambda primals, tangents: (tie_p.bind(*primals), tangents[-1])
ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct]
pe.def_trivial_padding(tie_p)
batching.defvectorized(tie_p)
class BIntRules:
@staticmethod
def physical_element_aval(dtype) -> core.ShapedArray:

View File

@ -812,6 +812,7 @@ batching.axis_primitive_batchers[psum_p] = \
partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes')
# We set a special bind rule for psum so that psum(1, 'i') can be evaluated at
# tracing time.
@psum_p.def_custom_bind

View File

@ -1465,6 +1465,8 @@ tf_not_yet_impl = [
"pmin",
"ppermute",
"psum",
"psum2",
"pbroadcast",
"pmax",
"pgather",
"reduce_scatter",
@ -3449,6 +3451,9 @@ def _reduce_precision(x, *, exponent_bits, mantissa_bits):
tf_impl[lax.reduce_precision_p] = _reduce_precision
tf_impl[lax_internal.tie_p] = lambda x, y: y
def _register_checkpoint_pytrees():
"""Registers TF custom container types as pytrees."""
m = tf.Module()

View File

@ -18,7 +18,7 @@ import enum
from functools import partial
import inspect
import itertools as it
import math
from math import prod
import operator as op
from typing import Any, Callable, Optional, TypeVar, Union
@ -104,7 +104,7 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
raise e('shard_map in_specs') from None
_check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
@memoize
def out_names_thunk():
@ -119,10 +119,15 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
e, *_ = prefix_errors(out_specs_, dummy)
raise e('shard_map out_specs') from None
return tuple(map(_canonicalize_spec, out_specs_flat))
if rewrite := check_rep:
fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk)
try:
out_flat = shard_map_p.bind(
flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat,
out_names_thunk=out_names_thunk, check_rep=check_rep, auto=auto)
fun, *args_flat, mesh=mesh, in_names=in_names_flat,
out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite,
auto=auto)
except _SpecError as e:
fails, = e.args
if not callable(out_specs):
@ -135,12 +140,12 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs,
except _RepError as e:
fails, = e.args
if not callable(out_specs):
msg = _rep_error(f, mesh, out_tree(), out_specs, fails)
msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails)
raise ValueError(msg) from None
return tree_unflatten(out_tree(), out_flat)
return wrapped
# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs
# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs
AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable
def _canonicalize_spec(spec: PartitionSpec) -> AxisNames:
if isinstance(spec, PartitionSpec):
@ -183,7 +188,7 @@ def _check_specs_vs_args(
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
raise ValueError(msg)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns)
fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
for d, ns in names.items()) else no_fail
for a, names in zip(in_avals, in_names_flat)]
if any(f is not no_fail for f in fail):
@ -239,10 +244,10 @@ def _spec_divisibility_error(
f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
if aval.shape[d] % prod(mesh.shape[n] for n in ns):
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
total = 'total ' if len(ns) > 1 else ''
sz = math.prod(mesh.shape[n] for n in ns)
sz = prod(mesh.shape[n] for n in ns)
msgs.append(
f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
@ -263,7 +268,7 @@ def _spec_divisibility_error(
f"padding the input and adapting '{fun_name}' appropriately.")
return msg
def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs,
fails: list[set | NoFail]) -> str:
fun_name = getattr(f, '__name__', str(f))
msgs = []
@ -332,10 +337,11 @@ class ShardMapPrimitive(core.Primitive):
def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh,
in_names: tuple[AxisNames, ...],
out_names_thunk: Callable[[], tuple[AxisNames, ...]],
check_rep: bool, auto: frozenset[AxisName]) -> Sequence[MaybeTracer]:
check_rep: bool, rewrite: bool, auto: frozenset[AxisName]
) -> Sequence[MaybeTracer]:
top_trace = core.find_top_trace(args)
fun, env_todo = process_env_traces(fun, top_trace.level, mesh, in_names,
out_names_thunk, check_rep, auto)
out_names_thunk, check_rep, rewrite, auto)
@as_hashable_function(closure=out_names_thunk)
def new_out_names_thunk():
@ -348,7 +354,8 @@ class ShardMapPrimitive(core.Primitive):
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_shard_map( # pytype: disable=attribute-error
shard_map_p, fun, tracers, mesh=mesh, in_names=in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto)
out_names_thunk=new_out_names_thunk, check_rep=check_rep,
rewrite=rewrite, auto=auto)
todos, _ = env_todo()
return map(core.full_lower, core.apply_todos(todos, outs))
@ -364,7 +371,7 @@ shard_map_p = ShardMapPrimitive('shard_map')
@lu.transformation_with_aux
def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
auto, *args: Any):
rewrite, auto, *args: Any):
outs = yield args, {}
todos, out_names_transforms = [], []
while True:
@ -377,7 +384,7 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
trace = ans._trace.main.with_cur_sublevel()
outs = map(trace.full_raise, outs)
outs, (todo, xform) = trace.post_process_shard_map(
outs, mesh, in_names, out_names_thunk, check_rep, auto)
outs, mesh, in_names, out_names_thunk, check_rep, rewrite, auto)
todos.append(todo)
out_names_transforms.append(xform)
yield outs, (tuple(todos), tuple(out_names_transforms))
@ -390,19 +397,19 @@ def _shard_map_staging(
in_names: tuple[AxisNames, ...],
out_names_thunk: Callable[[], tuple[AxisNames, ...]],
check_rep: bool,
rewrite: bool,
auto: frozenset,
) -> Sequence[pe.DynamicJaxprTracer]:
in_avals = [t.aval for t in in_tracers]
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(
f, main, in_avals_)
out_avals_ = map(_check_shapedarray, out_avals_generic)
jaxpr, genavals, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
out_avals_ = map(_check_shapedarray, genavals)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _output_rep(mesh, jaxpr, in_rep)
if check_rep:
out_rep = _check_rep(mesh, jaxpr, in_rep)
_check_reps(mesh, out_names_thunk(), out_rep)
out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_)
source_info = source_info_util.current()
@ -415,13 +422,80 @@ def _shard_map_staging(
jaxpr = pe.convert_constvars_jaxpr(jaxpr)
params = dict(mesh=mesh, in_names=in_names_staged,
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
check_rep=check_rep, auto=auto)
check_rep=check_rep, rewrite=rewrite, auto=auto)
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
jaxpr.effects, source_info)
trace.frame.add_eqn(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
Val = Any
# TODO(mattjj): caching
def _replication_rewrite_match(
mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]],
out_rep_dst: Sequence[set[AxisName]],
) -> core.ClosedJaxpr:
f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep))
f = _match_rep(f, mesh, out_rep_dst)
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts)
@lu.transformation
def _match_rep(mesh: Mesh, out_rep_dst: Sequence[set[AxisName]], *args):
out_vals, out_reps = yield args, {}
_check_reps2(mesh, out_rep_dst, out_reps)
out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)]
yield out_vals
def _rep_rewrite(
mesh: Mesh, jaxpr_: core.ClosedJaxpr,
in_rep: Sequence[set[AxisName]], *args: Val,
) -> tuple[tuple[Val], tuple[set[AxisName]]]:
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
env: dict[core.Var, tuple[Val, set[AxisName]]] = {}
def read(x: core.Atom) -> tuple[Val, set[AxisName]]:
return env[x] if isinstance(x, core.Var) else (x.val, set(mesh.axis_names))
def write(v: core.Var, val: Val, rep: set[AxisName]) -> None:
env[v] = (val, rep)
map(write, jaxpr.constvars, consts, [set(mesh.axis_names)] * len(consts))
map(write, jaxpr.invars, args, in_rep)
for e in jaxpr.eqns:
rule = _rewrite_rules.get(e.primitive, partial(_rule_missing, e.primitive))
in_vals, in_reps = unzip2(map(read, e.invars))
out_vals, out_reps = rule(mesh, in_reps, *in_vals, **e.params)
map(write, e.outvars, out_vals, out_reps)
out_vals, out_reps = unzip2(map(read, jaxpr.outvars))
return out_vals, out_reps
def _rule_missing(prim: core.Primitive, *_, **__):
raise NotImplementedError(
f"No replication rule for {prim}. As a workaround, pass the "
"`check_rep=False` argument to `shard_map`. To get this fixed, open an "
"issue at https://github.com/google/jax/issues")
def _replication_rewrite_nomatch(
mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]],
) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]:
f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep))
f, out_rep = _grab_out_rep(f)
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals)
return core.ClosedJaxpr(jaxpr_, consts), list(out_rep())
@lu.transformation_with_aux
def _grab_out_rep(*args):
yield (yield args, {})
def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
assert isinstance(aval, core.ShapedArray)
return aval
@ -429,7 +503,7 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray:
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ()))
return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
else:
raise NotImplementedError # TODO(mattjj): add table with handlers
@ -437,7 +511,7 @@ def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ()))
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})
@ -447,7 +521,7 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
# Type-checking
def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
check_rep, auto):
check_rep, rewrite, auto):
del auto # TODO(mattjj,parkers): check
for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names):
if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)):
@ -457,11 +531,11 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
core.check_jaxpr(jaxpr)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
out_rep = _output_rep(mesh, jaxpr, in_rep)
out_rep = _check_rep(mesh, jaxpr, in_rep)
for rep, dst in zip(out_rep, out_names):
if not _valid_repeats(mesh, rep, dst):
raise core.JaxprTypeError("shard_map can't prove output is sufficiently "
"replicated")
raise core.JaxprTypeError("shard_map can't prove output is "
"sufficiently replicated")
out_avals_sharded = [x.aval for x in jaxpr.outvars]
out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
return out_avals, jaxpr.effects
@ -470,7 +544,7 @@ core.custom_typechecks[shard_map_p] = _shard_map_typecheck
def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]:
return set(mesh.axis_names) - {n for ns in names.values() for n in ns}
def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]],
def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]],
) -> Sequence[set[AxisName]]:
env: dict[core.Var, set[AxisName]] = {}
@ -484,7 +558,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]],
map(write, jaxpr.invars, in_rep)
last_used = core.last_used(jaxpr)
for e in jaxpr.eqns:
rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive))
rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
if e.primitive.multiple_results:
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
@ -500,8 +574,8 @@ def _valid_repeats(mesh: Mesh, rep: set[AxisName], dst: AxisNames) -> bool:
# Lowering
def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep, auto):
del check_rep
check_rep, rewrite, auto):
del check_rep, rewrite
in_avals_ = [v.aval for v in jaxpr.invars]
out_avals_ = [x.aval for x in jaxpr.outvars]
in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in,
@ -557,7 +631,7 @@ def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
# Eager evaluation
def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
check_rep, auto):
check_rep, rewrite, auto):
if auto: raise NotImplementedError
del prim, auto
args = map(partial(_unmatch_spec, mesh), in_names, args)
@ -572,8 +646,10 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk,
del main, t, in_tracers, ans, out_tracers
out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_]
_check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types
if check_rep: _check_reps(mesh, out_names_thunk(), out_rep)
return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_)
if check_rep:
_check_reps(mesh, out_names_thunk(), out_rep)
return map(partial(_match_spec, mesh, check_rep), out_rep, out_names_thunk(),
outs_)
core.EvalTrace.process_shard_map = _shard_map_impl
def _names_to_pspec(names: AxisNames) -> PartitionSpec:
@ -587,7 +663,7 @@ def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType:
def _unmatch(mesh, src_tup, x):
src = _names_to_pspec(dict(src_tup))
dst = P(mesh.axis_names)
return shard_map(_add_singleton, mesh, (src,), dst)(x)
return shard_map(_add_singleton, mesh, (src,), dst, check_rep=False)(x)
def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray]
) -> None:
@ -602,14 +678,21 @@ def _check_reps(mesh, names, reps):
if any(f is not no_fail for f in fail): raise _RepError(fail)
class _RepError(Exception): pass
def _match_spec(mesh: Mesh, rep: set[AxisName], dst: AxisNames, x: JaxType
) -> JaxType:
with core.eval_context():
return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x)
def _check_reps2(mesh, reps_dest, reps):
fail = [src if not dst.issubset(src) else no_fail
for dst, src in zip(reps_dest, reps)]
if any(f is not no_fail for f in fail): raise _RepError(fail)
def _match(mesh, dst_tup, x):
def _match_spec(mesh: Mesh, check_rep: bool,
rep: set[AxisName], dst: AxisNames, x: JaxType) -> JaxType:
fn = HashablePartial(_match, mesh, check_rep, tuple(dst.items()))
with core.eval_context():
return jax.jit(fn)(x)
def _match(mesh, check_rep, dst_tup, x):
src = P(mesh.axis_names)
dst = _names_to_pspec(dict(dst_tup))
# TODO put back (?) needed for rep checking in eager? for now test rewrite
return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x)
def _rem_singleton(x): return x.reshape(x.shape[1:])
@ -640,7 +723,7 @@ class ShardMapTrace(core.Trace):
f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh)
with core.eval_context(), jax.disable_jit(False):
out_vals = jax.jit(f)(*in_vals)
rep_rule = _rep_rules.get(prim, partial(_rep_rule, prim))
rep_rule = _check_rules.get(prim, partial(_rule_missing, prim))
out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set()
if prim.multiple_results:
out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep
@ -758,92 +841,312 @@ def _device_put_eager_rule(mesh, x, *, src, device):
f"shard_map-decorated functions, but got device {device}")
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
# Static replication checking
# New primitives for efficient transposition
def _rep_rule(prim: core.Primitive, mesh: Mesh, *in_rep: set[AxisName],
**params: Any) -> set[AxisName] | list[set[AxisName]]:
raise NotImplementedError(
f"No replication rule for {prim}. As a workaround, pass the "
"`check_rep=False` argument to `shard_map`. To get this fixed, open an "
"issue at https://github.com/google/jax/issues")
# psum2_p is like psum_p except has a different transpose, so mostly copied:
psum2_p = core.AxisPrimitive('psum2')
psum2_p.multiple_results = True
psum2_p.def_impl(lax_parallel.psum_p.impl)
psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval)
mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p])
batching.primitive_batchers[psum2_p] = partial(lax_parallel._reduction_batcher, psum2_p)
batching.axis_primitive_batchers[psum2_p] = \
partial(lax_parallel._batched_reduction_collective, psum2_p,
lambda v, axis_size: axis_size * v)
core.axis_substitution_rules[psum2_p] = \
partial(lax_parallel._subst_all_names_in_param, 'axes')
def _psum2_transpose_rule(cts, *args, axes, axis_index_groups):
del args
return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)
ad.deflinear2(psum2_p, _psum2_transpose_rule)
_rep_rules: dict[core.Primitive, Callable] = {}
register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule)
register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule)
# pbroadcast_p is exactly the transpose of psum2_p
def pbroadcast(x, axis_name):
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
xs, treedef = tree_flatten(x)
ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None)
return tree_unflatten(treedef, ys)
pbroadcast_p = core.AxisPrimitive('pbroadcast')
pbroadcast_p.multiple_results = True
pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args)
pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args)
mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x)
def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups):
if any(type(axis) is int for axis in axes): raise NotImplementedError
vals_out = pbroadcast_p.bind(*vals_in, axes=axes,
axis_index_groups=axis_index_groups)
return vals_out, dims_in
batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher
def _pbroadcast_axis_batcher(size, name, trace_type, vals_in, dims_in, *, axes,
groups):
raise NotImplementedError # vmap with axis name involved in this primitive
batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_axis_batcher
core.axis_substitution_rules[pbroadcast_p] = \
partial(lax_parallel._subst_all_names_in_param, 'axes')
ad.deflinear2(pbroadcast_p,
lambda cts, *_, axes, axis_index_groups:
psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups))
# Rewrite rules and static replication checking for efficient transposition
_rewrite_rules: dict[core.Primitive, Callable] = {}
register_rewrite = lambda prim: lambda r: _rewrite_rules.setdefault(prim, r)
register_standard_rewrite = lambda prim: \
_rewrite_rules.setdefault(prim, partial(_standard_rewrite_rule, prim))
register_norewrite = lambda p: \
_rewrite_rules.setdefault(p, partial(_no_rewrite, p, _check_rules[p]))
_check_rules: dict[core.Primitive, Callable] = {}
register_check = lambda prim: lambda rule: _check_rules.setdefault(prim, rule)
register_standard_check = \
lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim))
def _no_rewrite(prim, rule, mesh, in_rep, *args, **params):
out_vals = prim.bind(*args,**params)
out_rep = rule(mesh, *in_rep, **params)
if prim.multiple_results:
out_rep_ = out_rep if type(out_rep) is list else [out_rep] * len(out_vals)
else:
out_vals, out_rep_ = [out_vals], [out_rep]
return out_vals, out_rep_
def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params):
# The standard rewrite inserts pbroadcasts but doesn't change the primitive.
out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names)
args_ = [pbroadcast(x, tuple(n for n in src if n not in out_rep_))
if src - out_rep_ else x for x, src in zip(args, in_rep)]
out_vals_ = prim.bind(*args_, **params)
out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_]
out_vals = [out_vals_] if not prim.multiple_results else out_vals_
return out_vals, out_rep
def _standard_check(prim, mesh, *in_rep, **__):
# The standard check requires args' and outputs' replications to be the same.
if in_rep and not in_rep[:-1] == in_rep[1:]:
raise Exception(f"Primitive {prim} requires argument replication types "
f"to match, but got {in_rep}. Please open an issue at "
"https://github.com/google/jax/issues")
return in_rep[0] if in_rep else set(mesh.axis_names)
def register_standard_collective(prim):
register_check(prim)(partial(_standard_collective_check, prim))
register_rewrite(prim)(partial(_standard_collective_rewrite, prim))
def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params):
# The standard collective check is varying -> varying over axis_name.
del mesh, params
if axis_name in x_rep:
raise Exception(f"Collective {prim} must be applied to a device-varying "
f"replication type, but got {x_rep} for collective acting "
f"over axis name {axis_name}. Please open an issue at "
"https://github.com/google/jax/issues")
return x_rep
def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params):
# The standard collective rewrite may insert a pbroadcast on the input.
if type(axis_name) is tuple: raise NotImplementedError # TODO
if params.get('axis_index_groups') is not None: raise NotImplementedError
x_rep, = in_rep
if axis_name in x_rep:
x = pbroadcast(x, (axis_name,))
out_val = prim.bind(x, axis_name=axis_name, **params)
return [out_val], [x_rep - {axis_name}]
def _standard_rep_rule(_, *in_rep, **__):
return set.intersection(*in_rep) if in_rep else set()
for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(),
windowed_reductions.__dict__.values(), fft.__dict__.values(),
linalg.__dict__.values(), ops.__dict__.values(),
ad_util.__dict__.values(), prng.__dict__.values(),
custom_derivatives.__dict__.values()):
if isinstance(o, core.Primitive): register_standard(o)
ad_util.__dict__.values(), prng.__dict__.values()):
if isinstance(o, core.Primitive):
register_standard_check(o)
register_standard_rewrite(o)
register_standard(lax_parallel.ppermute_p) # doesn't change replication
@register_rule(lax_parallel.psum_p)
def _psum_rule(_, *in_rep, axes, axis_index_groups):
@register_check(lax_parallel.psum_p)
def _psum_check(_, *in_rep, axes, axis_index_groups):
assert False # should be rewritten away
@register_rewrite(lax_parallel.psum_p)
def _psum_rewrite(_, in_rep, *args, axes, axis_index_groups):
# Replace the psum with psum2, insert pbroadcasts on input, replicated output.
if axis_index_groups is not None: raise NotImplementedError
axes = (axes,) if not isinstance(axes, tuple) else axes
return [r | set(axes) for r in in_rep] # introduces replication
out_rep = [r | set(axes) for r in in_rep] # TODO determinism (and elsewhere)
args_ = [pbroadcast(x, tuple(n for n in src if n not in dst))
if src - dst else x for x, src, dst in zip(args, in_rep, out_rep)]
out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups)
return out_val, out_rep
@register_rule(lax_parallel.all_gather_p)
def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size,
axis_index_groups, tiled):
if axis_index_groups is not None: raise NotImplementedError
if not tiled: raise NotImplementedError
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
return in_rep | set(axis_name) # introduces replication
@register_rule(lax_parallel.reduce_scatter_p)
def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size,
axis_index_groups, tiled):
if axis_index_groups is not None: raise NotImplementedError
if not tiled: raise NotImplementedError
return in_rep - {axis_name} # removes replication
@register_check(psum2_p)
def _psum2_check(_, *in_rep, axes, axis_index_groups):
assert type(axes) is tuple
if any(set(axes) & r for r in in_rep):
raise Exception("Collective psum must be applied to a device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
"https://github.com/google/jax/issues")
return [r | set(axes) for r in in_rep]
register_norewrite(psum2_p)
@register_rule(lax_parallel.all_to_all_p)
def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name,
axis_index_groups):
if axis_index_groups is not None: raise NotImplementedError
return in_rep - {axis_name} # removes replication
@register_rule(lax_parallel.axis_index_p)
def _axis_index_rule(mesh, *, axis_name):
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
@register_check(pbroadcast_p)
def _pbroadcast_check(_, *in_rep, axes, axis_index_groups):
assert type(axes) is tuple
if not all(set(axes) & r for r in in_rep):
raise Exception("Collective pbroadcast must be applied to a "
"non-device-varying "
f"replication type, but got {in_rep} for collective acting "
f"over axis name {axes}. Please open an issue at "
"https://github.com/google/jax/issues")
return [r - set(axes) for r in in_rep]
register_norewrite(pbroadcast_p)
register_standard_collective(lax_parallel.all_gather_p)
register_standard_collective(lax_parallel.all_to_all_p)
register_standard_collective(lax_parallel.ppermute_p)
register_standard_collective(lax_parallel.reduce_scatter_p)
@register_check(lax_parallel.axis_index_p)
def _axis_index_check(mesh, *, axis_name):
axis_name = (axis_name,) if not type(axis_name) is tuple else axis_name
return set(mesh.shape) - set(axis_name)
register_norewrite(lax_parallel.axis_index_p)
@register_rule(pjit.pjit_p)
def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs):
return _output_rep(mesh, jaxpr.jaxpr, in_rep)
@register_rule(debugging.debug_callback_p)
@register_rewrite(pjit.pjit_p)
def _pjit_rewrite(mesh, in_rep, *args, jaxpr, **kwargs):
jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep)
out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs)
return out_vals, out_rep
@register_check(pjit.pjit_p)
def _pjit_check(mesh, *in_rep, jaxpr, **kwargs):
return _check_rep(mesh, jaxpr.jaxpr, in_rep)
@register_check(core.call_p)
def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
return _check_rep(mesh, call_jaxpr, in_rep)
@register_check(debugging.debug_callback_p)
def _debug_callback_rule(mesh, *in_rep, **_):
return []
register_norewrite(debugging.debug_callback_p)
@register_rule(callback.pure_callback_p)
def _pure_callback_rule(mesh, *in_rep, result_avals, **_):
@register_check(callback.pure_callback_p)
def _pure_callback_rule(mesh, *_, result_avals, **__):
return [set()] * len(result_avals)
register_norewrite(callback.pure_callback_p)
@register_rule(dispatch.device_put_p)
def _device_put_rep_rule(mesh, x, *, src, device):
@register_check(dispatch.device_put_p)
def _device_put_rule(mesh, x, **_):
return x
register_norewrite(dispatch.device_put_p)
@register_rule(control_flow.loops.scan_p)
def _scan_rule(mesh, *in_rep, jaxpr, num_consts, num_carry, linear, length,
reverse, unroll):
const_rep, carry_rep, xs_rep = split_list(in_rep, [num_consts, num_carry])
@register_check(ad.custom_lin_p)
def _custom_lin_rule(mesh, *_, out_avals, **__):
return [set()] * len(out_avals)
register_norewrite(ad.custom_lin_p)
@register_check(control_flow.loops.scan_p)
def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_):
_, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry])
out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep)
carry_rep_out, _ = split_list(out_rep, [num_carry])
if not carry_rep_in == carry_rep_out:
raise Exception("Scan carry input and output got mismatched replication "
f"types {carry_rep_in} and {carry_rep_out}. Please open an "
"issue at https://github.com/google/jax/issues")
return out_rep
@register_rewrite(control_flow.loops.scan_p)
def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params):
const_rep, carry_rep_in, xs_rep = split_list(in_rep, [num_consts, num_carry])
for _ in range(1 + num_carry):
out_rep = _output_rep(mesh, jaxpr.jaxpr, [*const_rep, *carry_rep, *xs_rep])
if carry_rep == out_rep[:num_carry]:
in_rep_ = [*const_rep, *carry_rep_in, *xs_rep]
_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep_)
carry_rep_out, ys_rep = split_list(out_rep, [num_carry])
carry_rep_out = map(op.and_, carry_rep_in, carry_rep_out)
if carry_rep_in == carry_rep_out:
break
else:
carry_rep = map(op.and_, carry_rep, out_rep[:num_carry])
carry_rep_in = carry_rep_out
else:
assert False, 'Fixpoint not reached'
return out_rep
args = [pbroadcast(x, tuple(n for n in src if n not in dst))
if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)]
out_rep = [*carry_rep_out, *ys_rep]
jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep)
out_vals = control_flow.loops.scan_p.bind(
*args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params)
return out_vals, out_rep
@register_rewrite(core.closed_call_p)
def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs):
new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep)
out_vals = core.closed_call_p.bind(*args, jaxpr=new_jaxpr, **kwargs)
return out_vals, out_rep
@register_check(core.closed_call_p)
def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs):
return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)
@register_check(custom_derivatives.custom_jvp_call_p)
def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_thunk,
num_consts, symbolic_zeros):
return _check_rep(mesh, call_jaxpr.jaxpr, in_rep)
@register_rewrite(custom_derivatives.custom_vjp_call_jaxpr_p)
def _custom_vjp_call_jaxpr_rewrite(
mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
raise NotImplementedError(msg)
fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep)
_, in_rep_ = split_list(in_rep, [num_consts])
out_rep2 = []
@pe._memoize
def fwd_jaxpr_thunk_(*zeros):
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros))
fwd_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fwd_jaxpr, in_rep_)
out_rep2.append(out_rep)
return fwd_jaxpr_.jaxpr, fwd_jaxpr_.consts
bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_)
outs = custom_derivatives.custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_,
num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros)
out_rep = out_rep2[0] if out_rep2 else out_rep
return outs, out_rep
@register_check(custom_derivatives.custom_vjp_call_jaxpr_p)
def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_):
return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep)
del _check_rules[lax.tie_p]
@register_check(lax.tie_p)
def _tie_check(mesh, x_rep, y_rep):
return x_rep
register_norewrite(lax.tie_p)
# Batching
@ -853,12 +1156,13 @@ def _shard_map_batch(
in_names: tuple[AxisNames, ...],
out_names_thunk: Callable[[], tuple[AxisNames, ...]],
check_rep: bool,
rewrite: bool,
auto: frozenset) -> Sequence[batching.BatchTracer]:
in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers)
if all(bdim is batching.not_mapped for bdim in in_dims):
return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
out_names_thunk=out_names_thunk, check_rep=check_rep,
auto=auto)
rewrite=rewrite, auto=auto)
if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
raise NotImplementedError
fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
@ -874,7 +1178,7 @@ def _shard_map_batch(
new_params = dict(mesh=mesh, in_names=new_in_names,
out_names_thunk=new_out_names_thunk, check_rep=check_rep,
auto=auto)
rewrite=rewrite, auto=auto)
out_vals = prim.bind(fun, *in_vals, **new_params)
make_tracer = partial(batching.BatchTracer, trace,
source_info=source_info_util.current())
@ -882,8 +1186,8 @@ def _shard_map_batch(
batching.BatchTrace.process_shard_map = _shard_map_batch
def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep, auto):
del mesh, in_names, out_names_thunk, check_rep, auto
out_names_thunk, check_rep, rewrite, auto):
del mesh, in_names, out_names_thunk, check_rep, rewrite, auto
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
for t in out_tracers)
m = trace.main
@ -906,7 +1210,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names):
# Autodiff
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep, auto):
out_names_thunk, check_rep, rewrite, auto):
primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
which_nz = [ type(t) is not ad.Zero for t in tangents]
tangents = [t if type(t) is not ad.Zero else None for t in tangents]
@ -921,7 +1225,7 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz))
params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep,
auto=auto)
rewrite=rewrite, auto=auto)
f_jvp, out_tree = ad.traceable(f_jvp, in_tree)
result = shard_map_p.bind(f_jvp, *args, **params)
primal_out, tangent_out = tree_unflatten(out_tree(), result)
@ -931,8 +1235,8 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names,
ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
out_names_thunk, check_rep, auto):
del mesh, in_names, out_names_thunk, check_rep, auto
out_names_thunk, check_rep, rewrite, auto):
del mesh, in_names, out_names_thunk, check_rep, rewrite, auto
primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
out, treedef = tree_flatten((primals, tangents))
tangents_nz = [type(t) is not ad.Zero for t in tangents]
@ -946,7 +1250,7 @@ def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names,
ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process
def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
out_names_thunk, check_rep, auto):
out_names_thunk, check_rep, rewrite, auto):
in_pvals = [t.pval for t in tracers]
in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names)
@ -966,7 +1270,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
known_params = dict(mesh=mesh, in_names=(*known_in_names,),
out_names_thunk=known_out_names, check_rep=check_rep,
auto=auto)
rewrite=rewrite, auto=auto)
out = shard_map_p.bind(f_known, *in_consts, **known_params)
out_knowns, out_avals_sharded, jaxpr, env = aux()
out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)])
@ -980,7 +1284,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
unk_arg_tracers = [t for t in tracers if not t.is_known()]
unk_params = dict(mesh=mesh, in_names=unk_in_names,
out_names=unk_out_names, jaxpr=jaxpr, check_rep=False,
auto=auto)
rewrite=rewrite, auto=auto)
out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
@ -992,7 +1296,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_partial_eval_post_process(
trace, tracers, mesh, in_names, out_names_thunk, check_rep, auto):
trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto):
del check_rep
unk_tracers = [t for t in tracers if not t.is_known()]
jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers)
@ -1013,7 +1317,7 @@ def _shard_map_partial_eval_post_process(
staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env)
staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names,
out_names=(*out_names_unknown,), check_rep=False,
auto=auto)
rewrite=rewrite, auto=auto)
out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
@ -1053,10 +1357,11 @@ def _promote_scalar_residuals_jaxpr(jaxpr, res):
return jaxpr, res
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep, auto):
check_rep, rewrite, auto):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
-             else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
else x if rewrite
else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@ -1072,9 +1377,10 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
out = ad.backward_pass(
jaxpr_unknown.jaxpr, (), False, (), (*res_reshaped, *undefs), out_cts
)
-    return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
-            else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns)))
for ns, x in zip(in_names, out)]
return out
fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans)
fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree)
@ -1088,7 +1394,8 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
out_flat = shard_map_p.bind(
fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names),
out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto)
out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite,
auto=auto)
return tree_unflatten(out_tree(), out_flat)
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
@ -1201,7 +1508,6 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce
# Implementing pmap in terms of shard_map
def pmap(f, axis_name=None, *, in_axes=0, out_axes=0,
@ -1277,3 +1583,148 @@ def _get_devices(p, backend):
if jax.process_count() > 1:
return devs[:p.global_axis_size]
return devs[:p.local_axis_size]
### Rewrite!
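# The RewriteTrace machinery below tracks, for each value, the set of mesh axis
# names over which it is (at least) replicated, and applies the per-primitive
# rewrite rules registered above to insert explicit pbroadcast/psum2 operations
# wherever a value's replication has to change.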
class RewriteTracer(core.Tracer):
rep: set[AxisName]
val: Val
def __init__(self, trace, rep, val):
self._trace = trace
self.rep = rep
self.val = val
@property
def aval(self) -> core.AbstractValue:
return core.get_aval(self.val)
def full_lower(self) -> RewriteTracer:
return self
def __str__(self) -> str:
return str(self.val) # TODO(mattjj): could show replication info here
class RewriteTrace(core.Trace):
mesh: Mesh
dyna: int
def __init__(self, *args, mesh, dyna):
super().__init__(*args)
self.mesh = mesh
self.dyna = dyna
def pure(self, val) -> RewriteTracer:
return RewriteTracer(self, set(self.mesh.axis_names), val)
def lift(self, tracer: core.Tracer) -> RewriteTracer:
return RewriteTracer(self, set(self.mesh.axis_names), tracer)
def sublift(self, tracer: core.Tracer) -> RewriteTracer:
return RewriteTracer(self, tracer.rep, tracer.val)
def process_primitive(self, prim, in_tracers, params):
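    # Apply the primitive's replication-rewrite rule; it returns the rewritten
    # output values together with their replication sets.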
rule = _rewrite_rules.get(prim, partial(_rule_missing, prim))
in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers)
with core.new_dynamic(self.dyna):
out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params)
out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals)
return out_tracers if prim.multiple_results else out_tracers[0]
def process_call(self, call_primitive, f, in_tracers, params):
in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers)
f, out_reps = _rewrite_subtrace(f, self.main, tuple(in_reps))
with core.new_dynamic(self.dyna):
out_vals = call_primitive.bind(f, *in_vals, **params)
return map(partial(RewriteTracer, self), out_reps(), out_vals)
def post_process_call(self, call_primitive, out_tracers, params):
assert False # unreachable
def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
jvp, out_reps2 = _rewrite_subtrace(jvp, self.main, in_reps * 2)
with core.new_dynamic(self.dyna):
out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros)
fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2)
if not fst:
assert out_reps == out_reps[:len(out_reps) // 2] * 2
out_reps = out_reps[:len(out_reps) // 2]
return map(partial(RewriteTracer, self), out_reps, out_vals)
def post_process_custom_jvp_call(self, out_tracers, jvp_was_run):
assert False # unreachable
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
symbolic_zeros):
if symbolic_zeros:
msg = "Please open an issue at https://github.com/google/jax/issues !"
raise NotImplementedError(msg)
in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers)
fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps)
fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]]
fwd, out_reps2 = _rewrite_subtrace(fwd, self.main, fwd_in_reps)
bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps)
with core.new_dynamic(self.dyna):
out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees,
symbolic_zeros=symbolic_zeros)
fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2)
if not fst:
_, res_tree = out_trees()
_, out_reps = split_list(out_reps, [res_tree.num_leaves])
return map(partial(RewriteTracer, self), out_reps, out_vals)
def post_process_custom_vjp_call(self, out_tracers, _):
assert False # unreachable
# TODO process_axis_index
@lu.transformation
def _efficient_transpose_rewrite(mesh, in_names, out_names_thunk, *args):
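  # Derive each input's replication set from its in_names, run the wrapped
  # function under a RewriteTrace, then pbroadcast any output replicated over
  # more axes than its out_names entry expects.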
in_reps = map(partial(_in_names_to_rep, mesh), in_names)
lvl = core.dynamic_level()
with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
t = main.with_cur_sublevel()
in_tracers = map(partial(RewriteTracer, t), in_reps, args)
ans = yield in_tracers, {}
out_tracers = map(t.full_raise, ans)
out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
del main, t, in_tracers, out_tracers, ans
out_rep_dst = [frozenset(_unmentioned(mesh, n)) for n in out_names_thunk()]
out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst
else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)]
yield out_vals
@lu.transformation_with_aux
def _rewrite_subtrace(main, in_reps, *in_vals):
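  # Run the wrapped function under the existing RewriteTrace main, threading in
  # the inputs' replication sets and returning the outputs' sets as auxiliary data.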
assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals))
t = main.with_cur_sublevel()
in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals)
with core.new_dynamic(main.level):
outs = yield in_tracers, {}
out_tracers = map(t.full_raise, outs)
out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers)
yield out_vals, out_reps
def _rewrite_bwd(bwd, mesh, in_reps, reps_dst):
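  # Wrap a custom_vjp bwd function so it runs under its own RewriteTrace, then
  # match each output cotangent's replication to reps_dst via psum2/pbroadcast.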
def new_bwd(*args):
lvl = core.dynamic_level()
with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main:
bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), main, in_reps())
out = bwd_.call_wrapped(*args)
del main
return map(_match_replication, reps_thunk(), reps_dst, out)
return new_bwd
def _match_replication(src, dst, x):
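  # src and dst are the sets of mesh axes over which x is, respectively should
  # be, replicated: psum2 supplies replication missing from src, and pbroadcast
  # drops replication that dst does not want.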
if dst - src:
x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src),
axis_index_groups=None)
if src - dst:
x = pbroadcast(x, tuple(n for n in src if n not in dst))
return x


@ -37,6 +37,7 @@ from jax._src import xla_bridge
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src import linear_util as lu
from jax._src import tree_util
import jax.numpy as jnp
@ -113,9 +114,12 @@ class ShardMapTest(jtu.JaxTestCase):
mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None))
assert a.device_buffers[0].shape == (4, 2)
# NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use
# all_gather_invariant primitive, which differs in its output replication
# type compared to all_gather.
@jax.jit
@partial(shard_map, mesh=mesh,
in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y')))
in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))
def fwd(a):
return lax.all_gather(a, 'z', axis=0, tiled=True)
@ -559,8 +563,7 @@ class ShardMapTest(jtu.JaxTestCase):
def test_partial_eval_custom_axis_env(self):
mesh = Mesh(jax.devices(), ('i',))
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'),
check_rep=False) # check_rep=False b/c no scan rep rule yet
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(_):
_, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')),
None, None, length=1)
@ -861,6 +864,274 @@ class ShardMapTest(jtu.JaxTestCase):
self.assertIn("\"[('i',)]\"", mhlo_str)
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)
def test_rewrite_process_call(self):
def f(x):
return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
def test_rewrite_post_process_call(self):
# We shouldn't hit post_process_call here because of RewriteTrace's dynamic
# behavior (i.e. no data dependence).
mesh = jtu.create_global_mesh((4,), ('x',))
@jax.jit
@partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
def f(x):
return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x
x = jnp.arange(4.)
y = f(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
def test_rewrite_process_custom_jvp_call(self):
@jax.custom_jvp
def foo(x):
return 2. * x
@foo.defjvp
def foo_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
return foo(x), 2. * x_dot
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,))
self.assertAllClose(y2, 2 * x * x, check_dtypes=True)
self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True)
def test_rewrite_process_custom_vjp_call(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return foo(x), None
def foo_bwd(_, y_bar):
return 2. * y_bar,
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
def test_rewrite_process_custom_vjp_call_match_more_replicated(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return foo(x), None
def foo_bwd(_, y_bar):
return jnp.ones_like(y_bar), # diff! more replicated than primal/tangent
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True)
def test_rewrite_process_custom_vjp_call_match_less_replicated(self):
@jax.custom_vjp
def foo(x, y):
del y
return 2. * x
def foo_fwd(x, y):
return foo(x, y), y
def foo_bwd(y, _):
return y, None # diff! x_bar less replicated than primal/tangent
foo.defvjp(foo_fwd, foo_bwd)
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x, y: foo(x, y) * y, mesh,
in_specs=(P(), P('x')), out_specs=P('x'))
x = jnp.arange(4.)
y = jnp.arange(4 * 4.)
z = jax.jit(g)(x, y)
self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False)
z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y)
self.assertAllClose(z.sum(), z_, check_dtypes=False)
self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0),
check_dtypes=False)
def test_rewrite_custom_vjp_call_jaxpr(self):
@jax.custom_vjp
def foo(x):
return 2. * x
def foo_fwd(x):
return foo(x), None
def foo_bwd(_, y_bar):
return 2. * y_bar,
foo.defvjp(foo_fwd, foo_bwd)
def foo_scan(x):
y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1)
return y
mesh = jtu.create_global_mesh((4,), ('x',))
g = shard_map(lambda x: foo_scan(x) * x, mesh,
in_specs=(P('x'),), out_specs=P('x'))
x = jnp.arange(4.)
y = jax.jit(g)(x)
self.assertAllClose(y, 2 * x * x, check_dtypes=True)
y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x)
self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True)
self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True)
def test_transpose_identity(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def f(x):
return x
jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.)
e, = jaxpr.jaxpr.eqns
self.assertEmpty(e.params['jaxpr'].eqns)
jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,))
e, = jaxpr.jaxpr.eqns
self.assertEmpty(e.params['jaxpr'].eqns)
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P())
def g(x):
return jax.jit(lambda x: x)(x)
jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.)
e, = jaxpr.jaxpr.eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEmpty(e1.outvars)
self.assertEmpty(e2.params['jaxpr'].eqns)
def test_fanout_specs_transpose_to_psum(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x'))
def f(x):
return x
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e2, = e.params['jaxpr'].eqns
self.assertEqual(str(e2.primitive), 'psum2')
self.assertEqual(e2.params['axes'], ('x',))
def test_fanin_psum_transposes_to_fanout(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P())
def f(x):
return jax.lax.psum(x, 'x')
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.]))
e, = jaxpr.jaxpr.eqns
e1, = e.params['jaxpr'].eqns
self.assertEqual(str(e1.primitive), 'pbroadcast')
def test_psum_with_implicit_fanout_self_transposes(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def f(x):
return jax.lax.psum(x, 'x')
jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEqual(str(e1.primitive), 'psum2')
self.assertEqual(str(e2.primitive), 'pbroadcast')
def test_rewrite_binops(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x'))
def f(x, y):
return x * y
jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e = e.params['jaxpr'].eqns[0]
self.assertEqual(e.primitive.name, 'pbroadcast')
self.assertEqual(e.params['axes'], ('x',))
def test_rewrite_scan(self):
mesh = jtu.create_global_mesh((4,), ('x',))
@partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
def f(x):
x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None,
length=2)
return x
jaxpr = jax.make_jaxpr(f)(jnp.arange(4.))
e, = jaxpr.jaxpr.eqns
e, = e.params['jaxpr'].eqns
e1, e2 = e.params['jaxpr'].eqns
self.assertEqual(e1.primitive.name, 'psum2')
self.assertEqual(e2.primitive.name, 'pbroadcast')
def test_check_rep_false_grads(self):
# This test is redundant with the systematic tests below, but it serves as a
# direct regression test for a bug.
mesh = jtu.create_global_mesh((4,), ('heads',))
def f(q, k, v):
def body(q, k, v):
return q * k[None, :] + v[None, :]
out = shard_map(body, mesh, check_rep=False,
in_specs=(q_spec, kv_spec, kv_spec,),
out_specs=q_spec)(q, k, v)
return out.sum()
q_spec = P('heads', None)
kv_spec = P(None)
q = jax.device_put(jnp.arange(32.).reshape(4, 8), jax.sharding.NamedSharding(mesh, q_spec))
k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec))
jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)
class FunSpec(NamedTuple):
name: str
@ -1126,12 +1397,15 @@ class ShardMapSystematicTest(jtu.JaxTestCase):
self.assertAllClose(expected, out, check_dtypes=False)
@parameterized.named_parameters(
sample(config.FLAGS.jax_num_generated_cases, sample_shmap))
(name + f'_check_rep={check_rep}', *params, check_rep)
for (name, *params) in sample(config.FLAGS.jax_num_generated_cases, sample_shmap)
for check_rep in [True, False]
)
@jax.default_matmul_precision("float32")
def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _):
def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep):
mesh = self.make_mesh(mesh)
args = map(jnp.array, args)
f = shard_map(fun, mesh, in_specs, out_specs)
f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep)
if jit:
f = jax.jit(f)
jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)