From 70b58bbd30cf8ee58f821649c3c472fb37ae151e Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 31 Aug 2023 17:30:34 -0700 Subject: [PATCH] rolling forward shard_map transpose fixes The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests! The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those! The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule. Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa PiperOrigin-RevId: 561805840 --- docs/jep/17111-shmap-transpose.md | 515 ++++++++++++++++++++++ docs/jep/index.rst | 1 + jax/_src/core.py | 17 +- jax/_src/custom_derivatives.py | 1 + jax/_src/interpreters/ad.py | 1 + jax/_src/lax/lax.py | 11 + jax/_src/lax/parallel.py | 1 + jax/experimental/jax2tf/jax2tf.py | 5 + jax/experimental/shard_map.py | 687 +++++++++++++++++++++++++----- tests/shard_map_test.py | 286 ++++++++++++- 10 files changed, 1399 insertions(+), 126 deletions(-) create mode 100644 docs/jep/17111-shmap-transpose.md diff --git a/docs/jep/17111-shmap-transpose.md b/docs/jep/17111-shmap-transpose.md new file mode 100644 index 000000000..47da5ce87 --- /dev/null +++ b/docs/jep/17111-shmap-transpose.md @@ -0,0 +1,515 @@ +# Efficient transposition of replication-inducing collectives +*mattjj@*, *dougalm@* + +*August 2023* + +## Motivation + +We have an efficiency problem in automatically transposing `shmap`s containing +certain collectives. The issue arises with `psum` and `all_gather`, specifically +when the output of the collective is returned to the caller as an unmapped +output. And it's not an edge case: for example, it arises when applying `grad` +to a `shmap`-based batch data parallel neural network loss function which uses +`psum` to compute the total loss. + +We've known about this problem for some time. An analogous issue exists with +`pmap`, though it's been worked around by keeping `grad` inside `pmap` rather than +outside. A primary goal of the incomplete avals-with-names work was to address a +version of this transpose efficiency problem. This doc draws on those ideas, +while extending and revising them to handle more cases and to be much easier to +land. Indeed the solution proposed here only affects the `shmap` implementation. +The rest of the system need not be changed (yet). + +The main purpose of this doc is to define this transpose efficiency problem and +propose an easy-to-land solution. + +This doc is not about: +* logical axis names on arrays (the only axis names here are just like in + `shmap` and OG `pmap`); +* changing autodiff semantics (all the numbers and (non)errors are staying the + same, we're just making things more efficient); +* allowing user code to reflect on any new information, or really affecting user + code at all. 

## Problem: efficient transpose of `psum` or `all_gather` depends on whether cotangents are invariant across devices

Consider this semi-realistic example, meant to resemble a replicated-parameter
batch data parallel loss function:

```python
devices = jax.devices()  # 8 devices

@partial(shmap, mesh=Mesh(devices, ('batch',)),
         in_specs=(P(None, None), P('batch', None)),
         out_specs=P())
def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
  global_loss = lax.pmean(local_loss, 'batch')
  return global_loss
```

Notice the `out_specs=P()`, which indicates an unmapped output. If you're not
familiar with the notion of unmapped outputs, see the appendix at the bottom of
this document.

Most of the details in the `loss` example aren't important. All that matters for
our purposes is that we're applying `psum` (or rather `pmean = lambda x, name:
psum(x, name) / psum(1, name)`) at the end. So a distilled version looks like
this:

```python
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())
```

We even simplified notation by suppressing the `mesh` argument. In the examples to
follow it can be inferred from context.

What does the transpose look like? Writing `t` to mean function transpose, we
could evaluate `t(f1)(ybar)` for any `ybar` efficiently by applying the function
`¿f1_transpose?` below:

```python
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
```

But that's not the transpose we currently get as `t(f1)`.

Instead, the current recipe for transposition is roughly that we switch
`in_specs` and `out_specs`, do some division rescaling for unmapped outputs, and
transpose the body. Because `psum` is its own transpose (as an all-reduce sum),
we end up producing this transpose:

```python
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
              in_specs=P(), out_specs=P('i'))
```

This transpose gets the numbers right, but it's wasteful. We know statically
from the transpose's `in_specs=P()` that `ybar` has the same value for each function
instance, i.e. that its value is device-invariant for devices along the mesh
axis named `i`, and yet we apply a `psum` to it! That uses expensive communication
just to multiply the value on each device by 8. (Here 8 refers to the size of
axis `i`. The division by 8 comes from the original function's `out_specs=P()`; it
and the trivial `psum` basically cancel each other out.)

What are we doing wrong? We're not exploiting the fact that cotangents `ybar`
corresponding to `f1`'s unmapped outputs are guaranteed to be device-invariant;
instead, we're defensively `psum`ming them as if they weren't because `psum`'s
transpose can't be sure given the local information it has.
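
To spell out the waste in one line of arithmetic (a sketch, using the axis size
8 from the mesh above):

```python
# ybar is identical on all 8 devices along 'i', so the all-reduce only rescales:
#   psum(ybar / 8, 'i') == 8 * (ybar / 8) == ybar
# i.e. an AllReduceSum whose only effect is to undo the division by the axis size.
```
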
Sometimes the `psum` +is necessary, as in transposing `f2` with respect to its first argument: + +```python +# Example 2: shmap involving psum and *mapped* output with efficient transpose +f2 = shmap(lambda x, y: psum(g(x), 'i') * y, + in_specs=(P('i'), P('i')), out_specs=P('i')) + +# The transpose we currently get for Example 2 is efficient +t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')), + in_specs=(P('i'), P('i')), out_specs=P('i')) +``` + +Intuitively, if our transpose machinery could tell the difference between +Example 1 and Example 2, we could do better by avoiding the psum and division +where possible. + +The inefficient examples can be even smaller. Consider transposing this cursed +identity function: + +```python +# Example 3: cursed identity +cursed_identity = shmap(lambda x: x, P(), P()) + +# Currently we get these inefficient transposes +t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P()) +t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P()) +... +``` + +It keeps getting bigger the more we transpose. How embarrassing! + +And `psum` isn't the only culprit. Something analogous holds true for +`all_gather`: + +```python +# Example 4: all_gather to an unmapped output +f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P()) + +# Currently we get this inefficient transpose +t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i')) +``` + +This program is a bit artificial. Why do an `all_gather` and feed the result into +an unmapped output, rather than skipping the `all_gather` in the body and just +using `out_specs=P('i')` to collect the results? But even though it's cooked-up, +this example nevertheless exhibits a transpose which unnecessarily performs +communication (we could have just performed a non-communicating slice), +analogous to Example 1 for `psum`. + +Also analogously to the `psum` examples, the defensive `psum_scatter` is +necessary in some cases: + +```python +# Example 5: all_gather to a mapped output +f5 = shmap(lambda x, y: all_gather(x, 'i') * y, + in_specs=(P('i'), P('i')), out_specs=P('i')) + +# Currently we get this efficient transpose +t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'), + in_specs=(P('i'), P('i')), out_specs=P('i')) +``` + +So how do we avoid these inefficient transposes? + +## Solutions + +Here are two solution ideas. They aren't mutually exclusive. But (spoilers) the +second one is better, and it's all we need. + +### Partial solution "P-sum": build the ability to express a `psum` into `out_specs` + +This solution is a bit of a strawperson because it would offer only an awkward +way to write programs. And it wouldn't even fix everything! But it's worth +considering, if only to motivate a more complete solution. + +Example 4 above is artificial because we could have just used `out_specs` instead +of an `all_gather` in the body: + +```python +# Example 4 again +f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P()) + +# Why didn't we just write it like this? +f4_better = shmap(lambda x: x, P('i'), P('i')) +``` + +The `f4_better` version doesn't have any transposition problems, since the +transpose problems arise from collectives in the body. + +Analogously, we could fix Example 1 by extending `out_specs` so that they can +express summing: + +```python +# Example 1 again +f1 = shmap(lambda x: psum(g(x), 'i'), + in_specs=P('i'), out_specs=P()) + +# What if we could write an output sum like this? 
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i'))  # sum='i' means sum over that axis

# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), out_specs=P(sum='i'))
```

So offering `psum`s built into `out_specs` fixes the transpose problem of
Example 1. But it doesn't fully fix the cursed identity transpose in Example 3:

```python
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())

# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
```

It's an improvement since the program doesn't continue to get bigger as we keep
transposing, but we're still doing wasteful communication.

### Full solution: statically track device-varying vs device-invariant intermediates, plus new primitives

This solution has two components:
1. track when values are guaranteed to be device-invariant vs device-varying
   over particular mesh axes, and
2. decompose `psum` into a two-step process, introducing a new `pbroadcast`
   primitive, and introduce new primitives for `all_gather` and its transposes.

Morally, the tracking of device-invariant vs device-varying information is a
type-level consideration. But for the expedience of our first implementation, we
don't need to literally add the information to abstract values or jaxpr types.
Before we get to implementation, we'll first introduce the idea using types.

Also to follow is a discussion of making the user API convenient and backward
compatible. But to first introduce the idea, we'll ignore convenience and
instead write code that is as explicit as possible.

#### Tracking device invariance in avals (a.k.a. avals-with-names, revived)

We can sometimes tell from static information alone that the values of some
intermediate variables in the body of a `shmap` are guaranteed to be invariant
along a mesh axis, in the sense that the function instances (and their
corresponding devices) along the mesh axis must all be computing with the same
value. We'll call such values device-invariant. For values that are not
device-invariant, we'll say they're device-varying, though really we mean
potentially device-varying from the point of view of the type system.

To encode device variance in types, we'll extend the syntax of types for arrays.
We'll write things like `x:f32[3,4]{i}` to indicate that `x` is (potentially)
device-varying along mesh axis `i` (and device-invariant over any other mesh
axes of the `shmap`). More generally, we'll say the grammar for array type
syntax is something like

```
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
```

We'll also update the typing rules to handle device variance types:
* for first-order primitives other than collectives
  - for multi-arity primitives, the operand device variance types must be equal
    where shapes must be equal, e.g. `mul x:f32[s1]{r1} y:f32[s2]{r2}` requires
    `r1 == r2` in addition to `s1 == s2`
  - the output device variance type must be the same as the operand(s)
* for higher-order primitives
  - we just instantiate any type variables including the device variance type
    (and checking types for equality checks their device variance types are
    equal)
  - (when performing type inference, e.g.
for branches of a `cond`, we take the
    union of the sets of axis names in device variance types)
* for first-order collectives
  - a collective can either accept a device-varying or device-invariant input
    (along a mesh axis corresponding to its axis name parameter); it's an error
    to pass a device-invariant operand to a collective which accepts
    device-varying operands and vice-versa
  - a collective can either produce a device-varying or device-invariant output
  - see the table below

As a side benefit, whatever logic implements this type checking can subsume
`shmap`'s "static analysis" check for whether a `shmap` body function is
compatible with any unmapped `out_specs`.

Here's a table summarizing the device variance typing for collective primitives:

| Name | Device variance type | Example | Lowers to HLO | Transpose |
| --- | --- | --- | --- | --- |
| `psum2` | `Varying -> Invariant` | `y:f32[3]{j} = psum(x:f32[3]{i,j}, axis='i')` | `AllReduceSum` (communication) | `pbroadcast` |
| `pbroadcast` | `Invariant -> Varying` | `y:f32[3]{i} = pbroadcast(x:f32[3], 'i')` | no-op (no communication) | `psum` |
| `all_to_all` | `Varying -> Varying` | `y:f32[16]{i} = all_to_all(x:f32[16]{i}, 'i', 0, 0)` | `AllToAll` (communication) | `all_to_all` |
| `axis_index` | `() -> Varying` | `idx:i32[]{i} = axis_index('i')` | `ReplicaId` and some arithmetic (no communication) | n/a |
| `psum_scatter` | `Varying -> Varying` | `y:f32[2]{i} = psum_scatter(x:f32[16]{i}, 'i')` | `ReduceScatterSum` (communication) | `all_gather` |
| `all_gather` | `Varying -> Varying` | `y:f32[16]{i} = all_gather(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `psum_scatter` |
| `pscatter` | `Invariant -> Varying` | `y:f32[2]{i} = pscatter(x:f32[16], 'i')` | `lambda x: x[axis_index('i'), None]` (no communication) | `all_gather_invariant` |
| `all_gather_invariant` | `Varying -> Invariant` | `y:f32[16] = all_gather_invariant(x:f32[2]{i}, 'i')` | `AllGather` (communication) | `pscatter` |

There are some surprising things here!
* We introduced several new primitives, including
  - `pbroadcast`, which interestingly lowers to a no-op
  - `all_gather_invariant`, which lowers to the same thing as `all_gather` but
    has a different device variance type (essentially `all_gather` has a
    `pbroadcast` fused into it, whereas `all_gather_invariant` does not)
  - `pscatter` which is the dual (transpose) of `all_gather_invariant`
* `all_gather` has a device-varying result

Intuitively, the reason to introduce `pbroadcast` (other than to make the typing
rules work) is so that `psum` can transpose to a physical no-op. The reason we
need `all_gather` to have a device-varying result is so that we can transpose it
to `psum_scatter`; if we instead left it with a device-invariant result, we
might need a downstream `pbroadcast`, and that composition would transpose to an
inefficient `psum` followed by slicing / `pscatter`. So instead we have a
`pbroadcast` "fused into" the `all_gather`, thus allowing for an efficient
transpose into `psum_scatter`. We provide `all_gather_invariant` and its
transpose `pscatter` mainly for completeness; it's unlikely users will need it
(it corresponds to the situation in Example 4, which is easy to write
differently using `out_specs`).

Interestingly, the `psum` and `pbroadcast` transpose pair correspond to the
`psum_idrev` and `id_psumrev` that users introduced while training LLMs with
`pmap`.
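
As a concrete (and hypothetical) sketch of how these pieces compose, here is
pseudocode expressing the familiar `psum` in terms of the two new primitives,
assuming both take `(x, axis_name)` arguments like `lax.psum`:

```python
def psum_via_new_primitives(x, axis_name, x_is_invariant_along_axis):
  # pbroadcast lowers to a no-op: no communication, it only switches the device
  # variance type from invariant to varying so psum2's input type matches.
  if x_is_invariant_along_axis:
    x = pbroadcast(x, axis_name)
  # psum2 is an AllReduceSum taking a device-varying input to a
  # device-invariant output.
  return psum2(x, axis_name)
```

Because `psum2` and `pbroadcast` transpose to each other, transposing such a
composition never adds communication beyond what the forward computation already
performs, which is exactly what keeps the cursed-identity example from growing.
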

#### How this system solves the inefficient transpose examples

Consider again the simplified motivating example:

```python
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
           in_specs=P('i'), out_specs=P())

# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
  w:f32[]{i} = g(x)
  y:f32[]{} = psum(w, 'i')
  return y
```

With these new rules, the transpose is:

```python
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
              in_specs=P(), out_specs=P('i'))

# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f1_transpose(ybar: f32[]):
  wbar:f32[]{i} = pbroadcast(ybar, 'i')
  xbar:f32[3,4]{i} = transpose(g)(wbar)
  return xbar
```

where evaluating the `pbroadcast` application involves no communication or FLOPs
at all; it's a no-op. Notice that if we keep transposing the body does not grow
in size; indeed `t(t(f1)) == f1`. Efficiency achieved!

And we wouldn't mess up the other examples either, so long as we `pbroadcast` to
make the types check where needed:

```python
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
           in_specs=(P('i'), P('i')), out_specs=P('i'))

# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
                 in_specs=(P('i'), P('i')), out_specs=P('i'))


# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.

# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
```

Intuitively, in Example 1 we now only have "half the original psum", whereas in
Example 2 we get both "halves". For Example 3 we never need any operations in
the body at all.

For the `all_gather` examples, Example 4 would need to use
`all_gather_invariant` to have an efficient transpose (though it'd be better to
use `out_specs` instead of the collective in the body):

```python
# Example 4 rewritten with explicit all_gather_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())

# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
  y:f32[8]{} = all_gather_invariant(x, 'i')
  return y

# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
  xbar:f32[1]{i} = pscatter(ybar, 'i')
  return xbar
```

For Example 5, using the device-varying `all_gather` works as we'd want:

```python
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
  z:f32[8]{i} = all_gather(x, 'i')
  w:f32[8]{i} = z * y
  return w

# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
  zbar:f32[8]{i} = wbar * y
  xbar:f32[1]{i} = psum_scatter(zbar, 'i')
  return xbar
```

### How to make the API convenient for users (and backward compatible)

But what user wants to write `pbroadcast`s?
And what developer wants to break +lots of existing user code involving `psum`s which are not fed into unmapped +outputs? Not me! + +Instead we can automatically insert the `pbroadcast`s. It's a bit analogous to how +we do automatic rank promotion at the `jax.numpy` layer, inserting broadcasts to +avoid rank mismatch errors in binary operators. But it's much simpler since we +don't need to contend with shape tuples. The typical rule is: whenever we see a +multi-arity operation where the operands disagree in their device variance +types, take the union of operands' device variance types' axis name sets and +insert `pbroadcast`s to lift each operand to the resulting device variance type. + +Automatically inserting `pbroadcast`s just before they're needed may mean we apply +the same `pbroadcast` to the same operand multiple times, creating common +subexpressions. When we transpose, those could turn into a sum-of-`psum`s rather +than a `psum`-of-sum. We'll rely on the compiler to clean that up as appropriate. +If it's a problem then we could add some simple memoization to the +`pbroadcast`-insertion pass. + +The user API for `all_gather` will mean `all_gather_p` by default (not +`all_gather_invariant_p`), covering the common case and meaning no `pbroadcast`s +must be inserted. + +We can provide an option on `shmap` to disable this automatic insertion of +`pbroadcast`s, in which case it'll be up to the user to ensure type-correctness. +This explicit option may be appealing to some who want to be explicit about +where the `psum`s occur in the backward pass. + +### How to implement the solution + +The key to making the implementation lightweight is that **we aren't going to +add these types to avals or jaxprs**. At least, not at first. That can be +expensive because it requires updating the rest of JAX, e.g. all consumers of +avals and jaxprs may need to handle the new types. We're not falling for that +again! + +Instead we're going to keep these extended types as metadata internal to +`shmap`, just like the current "replication checking for `out_specs`" machinery +is internal to `shmap`. Indeed this solution amounts to a relatively small +extension to that existing machinery: it was already tracking the same +information; now we're just adding the `pbroadcast`s. + +We have at least two options for where to perform the `pbroadcast` insertion: +1. just before transposition, in the transpose rule, where we have a jaxpr of + the computation to be transposed; +2. in every `shmap` body, whether eagerly executed or staged out, like the + current "replication checking for `out_specs`" machinery. +The former may end up being easier since we only have to handle the jaxpr case, +and only linear primitives. But we'll start by trying the latter so the +implementation here is a strict revision/extension to the existing +replication-checking logic. + +## Appendix: defining and motivating maps with unmapped inputs and outputs + +For concreteness, we'll mostly focus on `shmap`, though these same ideas apply +to e.g. `pmap` and probably `xmap`. + +An argument/input is _unmapped_ along a mesh axis when the corresponding entry +of `in_specs` doesn't mention that mesh axis's name. Logically it means that +each function instance along that mesh axis gets the same value for the +argument. To the caller, each operand is sliced according to the mesh axes over +which the operand is mapped, whereas there is no slicing for mesh axes over +which the operand is unmapped. 
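
For instance, here is a small, hypothetical sketch on an 8-device mesh with one
axis `'i'`, contrasting a mapped and an unmapped input:

```python
# 'w' is unmapped along 'i' (its spec doesn't mention 'i'), so every function
# instance receives the whole f32[2] value of w. 'x' is mapped along 'i', so
# each of the 8 instances receives one f32[2] slice of the f32[16] operand.
f = shmap(lambda w, x: w * x,
          mesh=Mesh(jax.devices(), ('i',)),
          in_specs=(P(), P('i')), out_specs=P('i'))

w = jnp.ones(2)
x = jnp.arange(16.)
y = f(w, x)  # f32[16], concatenating the 8 per-instance f32[2] results
```
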
+ +An output is _unmapped_ along a mesh axis when the corresponding entry of +`out_specs` doesn't mention that mesh axis's name. Logically it means each +function instance along that mesh axis must return the same value. To the +caller, each result of the `shmap` is formed by concatenating the return values +of every function instance along which the outputs are mapped, whereas for mesh +axes over which the output is unmapped only one copy of the value is used. + +See [the `shmap` +JEP](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) for examples +of unmapped inputs and outputs. For comparison, in `vmap` unmapped +inputs/outputs are indicated by using `in_axes` / `out_axes` of `None` (rather +than an `int`). + +Here are reasons we like unmapped inputs and outputs for `shmap`: +* **Same expressiveness as `pjit`.** Anything `pjit` can do, the `shmap` escape + hatch should be able to do too. Or else we'd have a lacking escape hatch! If + we didnt have unmapped outputs in `shmap` then we couldn't express the same + batch-parallel loss function computations as `pjit`. +* **Closed-over inputs.** Closed-over inputs essentially correspond to unmapped + inputs, and... +* **Closure under transposition.** Once we have unmapped inputs, it's natural to + be able to transpose to unmapped outputs. + +So unmapped outputs are both canonical and useful! diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 8456266d7..30ca8a846 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -48,6 +48,7 @@ Then create a pull request that adds a file named 12049: Type Annotation Roadmap for JAX <12049-type-annotations> 14273: `shard_map` (`shmap`) for simple per-device code <14273-shard-map> 15856: `jax.extend`, an extensions module <15856-jex> + 17111: Efficient transposition of `shard_map` (and other maps) <17111-shmap-transpose> Several early JEPs were converted in hindsight from other documentation, diff --git a/jax/_src/core.py b/jax/_src/core.py index 1b6f22d27..20c2001e0 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1084,8 +1084,7 @@ def _why_alive_container_info(container, obj_id) -> str: @contextmanager -def new_main(trace_type: type[Trace], - dynamic: bool = False, +def new_main(trace_type: type[Trace], dynamic: bool = False, **payload) -> Generator[MainTrace, None, None]: # See comments in https://github.com/google/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack @@ -1111,6 +1110,20 @@ def new_main(trace_type: type[Trace], leaked_tracers = maybe_find_leaked_tracers(t()) if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers) +@contextmanager +def new_dynamic(level: int) -> Generator[None, None, None]: + stack = thread_local_state.trace_state.trace_stack + prev_dynamic, stack.dynamic = stack.dynamic, stack.stack[level] + _update_thread_local_jit_state(stack.dynamic) + try: + yield + finally: + stack.dynamic = prev_dynamic + _update_thread_local_jit_state(stack.dynamic) + +def dynamic_level() -> int: + return thread_local_state.trace_state.trace_stack.dynamic.level + @contextmanager def new_base_main(trace_type: type[Trace], **payload) -> Generator[MainTrace, None, None]: diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index c76acebcb..b1b67b599 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -853,6 +853,7 @@ def _custom_vjp_call_jaxpr_jvp( tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, 
symbolic_zeros=symbolic_zeros) + tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) tangents_out = map(ad.recast_to_float0, primals_out, tangents_out) return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ed4a0c7f4..ed896dd5c 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -403,6 +403,7 @@ class JVPTrace(Trace): tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) + tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out) tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 64621f99c..b6f8da0bc 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4966,6 +4966,17 @@ def _empty_lower(ctx, *, dtype): mlir.register_lowering(empty_p, _empty_lower) +tie_p = core.Primitive('tie') +tie_p.def_impl(lambda x, y: y) +tie_p.def_abstract_eval(lambda x, y: y) +mlir.register_lowering(tie_p, lambda ctx, x, y: [y]) +ad.primitive_jvps[tie_p] = \ + lambda primals, tangents: (tie_p.bind(*primals), tangents[-1]) +ad.primitive_transposes[tie_p] = lambda ct, x, _: [None, ct] +pe.def_trivial_padding(tie_p) +batching.defvectorized(tie_p) + + class BIntRules: @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 392065b23..ce4e82632 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -812,6 +812,7 @@ batching.axis_primitive_batchers[psum_p] = \ partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v) core.axis_substitution_rules[psum_p] = partial(_subst_all_names_in_param, 'axes') + # We set a special bind rule for psum so that psum(1, 'i') can be evaluated at # tracing time. 
@psum_p.def_custom_bind diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 9a3e36551..0b5427605 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1465,6 +1465,8 @@ tf_not_yet_impl = [ "pmin", "ppermute", "psum", + "psum2", + "pbroadcast", "pmax", "pgather", "reduce_scatter", @@ -3449,6 +3451,9 @@ def _reduce_precision(x, *, exponent_bits, mantissa_bits): tf_impl[lax.reduce_precision_p] = _reduce_precision +tf_impl[lax_internal.tie_p] = lambda x, y: y + + def _register_checkpoint_pytrees(): """Registers TF custom container types as pytrees.""" m = tf.Module() diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 17ce1266f..ddf8cabf7 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -18,7 +18,7 @@ import enum from functools import partial import inspect import itertools as it -import math +from math import prod import operator as op from typing import Any, Callable, Optional, TypeVar, Union @@ -104,7 +104,7 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, raise e('shard_map in_specs') from None _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, args_flat) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree) + fun, out_tree = flatten_fun_nokwargs(fun, in_tree) @memoize def out_names_thunk(): @@ -119,10 +119,15 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, e, *_ = prefix_errors(out_specs_, dummy) raise e('shard_map out_specs') from None return tuple(map(_canonicalize_spec, out_specs_flat)) + + if rewrite := check_rep: + fun = _efficient_transpose_rewrite(fun, mesh, in_names_flat, out_names_thunk) + try: out_flat = shard_map_p.bind( - flat_fun, *args_flat, mesh=mesh, in_names=in_names_flat, - out_names_thunk=out_names_thunk, check_rep=check_rep, auto=auto) + fun, *args_flat, mesh=mesh, in_names=in_names_flat, + out_names_thunk=out_names_thunk, check_rep=check_rep, rewrite=rewrite, + auto=auto) except _SpecError as e: fails, = e.args if not callable(out_specs): @@ -135,12 +140,12 @@ def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, except _RepError as e: fails, = e.args if not callable(out_specs): - msg = _rep_error(f, mesh, out_tree(), out_specs, fails) + msg = _inout_rep_error(f, mesh, out_tree(), out_specs, fails) raise ValueError(msg) from None return tree_unflatten(out_tree(), out_flat) return wrapped -# Internally use AxisNames = Dict[int, Tuple[AxisName, ...]], not PartitionSpecs +# Internally use AxisNames = dict[int, tuple[AxisName, ...]], not PartitionSpecs AxisNames = dict[int, tuple[AxisName, ...]] # TODO(mattjj): make it hashable def _canonicalize_spec(spec: PartitionSpec) -> AxisNames: if isinstance(spec, PartitionSpec): @@ -183,7 +188,7 @@ def _check_specs_vs_args( msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) raise ValueError(msg) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns) + fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) for d, ns in names.items()) else no_fail for a, names in zip(in_avals, in_names_flat)] if any(f is not no_fail for f in fail): @@ -239,10 +244,10 @@ def _spec_divisibility_error( f"parameter '{list(ba.arguments.keys())[arg_key.idx]}',") names = _canonicalize_spec(spec) for d, ns in names.items(): - if aval.shape[d] % math.prod(mesh.shape[n] for n in ns): + if aval.shape[d] % 
prod(mesh.shape[n] for n in ns): axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" total = 'total ' if len(ns) > 1 else '' - sz = math.prod(mesh.shape[n] for n in ns) + sz = prod(mesh.shape[n] for n in ns) msgs.append( f"* args{keystr(fail_key)} of shape {aval.str_short()}{extra} " f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " @@ -263,8 +268,8 @@ def _spec_divisibility_error( f"padding the input and adapting '{fun_name}' appropriately.") return msg -def _rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, - fails: list[set | NoFail]) -> str: +def _inout_rep_error(f: Callable, mesh: Mesh, tree: PyTreeDef, specs: Specs, + fails: list[set | NoFail]) -> str: fun_name = getattr(f, '__name__', str(f)) msgs = [] for (spec_key, spec), (fail_key, rep) in _iter_paths(tree, specs, fails): @@ -332,10 +337,11 @@ class ShardMapPrimitive(core.Primitive): def bind(self, fun: lu.WrappedFun, *args: MaybeTracer, mesh: Mesh, in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], - check_rep: bool, auto: frozenset[AxisName]) -> Sequence[MaybeTracer]: + check_rep: bool, rewrite: bool, auto: frozenset[AxisName] + ) -> Sequence[MaybeTracer]: top_trace = core.find_top_trace(args) fun, env_todo = process_env_traces(fun, top_trace.level, mesh, in_names, - out_names_thunk, check_rep, auto) + out_names_thunk, check_rep, rewrite, auto) @as_hashable_function(closure=out_names_thunk) def new_out_names_thunk(): @@ -348,7 +354,8 @@ class ShardMapPrimitive(core.Primitive): tracers = map(top_trace.full_raise, args) outs = top_trace.process_shard_map( # pytype: disable=attribute-error shard_map_p, fun, tracers, mesh=mesh, in_names=in_names, - out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto) + out_names_thunk=new_out_names_thunk, check_rep=check_rep, + rewrite=rewrite, auto=auto) todos, _ = env_todo() return map(core.full_lower, core.apply_todos(todos, outs)) @@ -364,7 +371,7 @@ shard_map_p = ShardMapPrimitive('shard_map') @lu.transformation_with_aux def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, - auto, *args: Any): + rewrite, auto, *args: Any): outs = yield args, {} todos, out_names_transforms = [], [] while True: @@ -377,7 +384,7 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, trace = ans._trace.main.with_cur_sublevel() outs = map(trace.full_raise, outs) outs, (todo, xform) = trace.post_process_shard_map( - outs, mesh, in_names, out_names_thunk, check_rep, auto) + outs, mesh, in_names, out_names_thunk, check_rep, rewrite, auto) todos.append(todo) out_names_transforms.append(xform) yield outs, (tuple(todos), tuple(out_names_transforms)) @@ -390,19 +397,19 @@ def _shard_map_staging( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, + rewrite: bool, auto: frozenset, ) -> Sequence[pe.DynamicJaxprTracer]: in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) main = trace.main with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic( - f, main, in_avals_) - out_avals_ = map(_check_shapedarray, out_avals_generic) + jaxpr, genavals, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) + out_avals_ = map(_check_shapedarray, genavals) _check_names(out_names_thunk(), out_avals_) + in_rep = map(partial(_in_names_to_rep, mesh), in_names) if check_rep: - in_rep = map(partial(_in_names_to_rep, 
mesh), in_names) - out_rep = _output_rep(mesh, jaxpr, in_rep) + out_rep = _check_rep(mesh, jaxpr, in_rep) _check_reps(mesh, out_names_thunk(), out_rep) out_avals = map(partial(_unshard_aval, mesh), out_names_thunk(), out_avals_) source_info = source_info_util.current() @@ -415,13 +422,80 @@ def _shard_map_staging( jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, - check_rep=check_rep, auto=auto) + check_rep=check_rep, rewrite=rewrite, auto=auto) eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params, jaxpr.effects, source_info) trace.frame.add_eqn(eqn) return out_tracers pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging + +Val = Any + +# TODO(mattjj): caching +def _replication_rewrite_match( + mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]], + out_rep_dst: Sequence[set[AxisName]], + ) -> core.ClosedJaxpr: + f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep)) + f = _match_rep(f, mesh, out_rep_dst) + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + return core.ClosedJaxpr(jaxpr_, consts) + +@lu.transformation +def _match_rep(mesh: Mesh, out_rep_dst: Sequence[set[AxisName]], *args): + out_vals, out_reps = yield args, {} + _check_reps2(mesh, out_rep_dst, out_reps) + out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst + else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)] + yield out_vals + + +def _rep_rewrite( + mesh: Mesh, jaxpr_: core.ClosedJaxpr, + in_rep: Sequence[set[AxisName]], *args: Val, + ) -> tuple[tuple[Val], tuple[set[AxisName]]]: + jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts + + env: dict[core.Var, tuple[Val, set[AxisName]]] = {} + + def read(x: core.Atom) -> tuple[Val, set[AxisName]]: + return env[x] if isinstance(x, core.Var) else (x.val, set(mesh.axis_names)) + + def write(v: core.Var, val: Val, rep: set[AxisName]) -> None: + env[v] = (val, rep) + + map(write, jaxpr.constvars, consts, [set(mesh.axis_names)] * len(consts)) + map(write, jaxpr.invars, args, in_rep) + for e in jaxpr.eqns: + rule = _rewrite_rules.get(e.primitive, partial(_rule_missing, e.primitive)) + in_vals, in_reps = unzip2(map(read, e.invars)) + out_vals, out_reps = rule(mesh, in_reps, *in_vals, **e.params) + map(write, e.outvars, out_vals, out_reps) + out_vals, out_reps = unzip2(map(read, jaxpr.outvars)) + return out_vals, out_reps + +def _rule_missing(prim: core.Primitive, *_, **__): + raise NotImplementedError( + f"No replication rule for {prim}. As a workaround, pass the " + "`check_rep=False` argument to `shard_map`. 
To get this fixed, open an " + "issue at https://github.com/google/jax/issues") + +def _replication_rewrite_nomatch( + mesh: Mesh, jaxpr: core.ClosedJaxpr, in_rep: Sequence[set[AxisName]], + ) -> tuple[core.ClosedJaxpr, list[set[AxisName]]]: + f = lu.wrap_init(partial(_rep_rewrite, mesh, jaxpr, in_rep)) + f, out_rep = _grab_out_rep(f) + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) + return core.ClosedJaxpr(jaxpr_, consts), list(out_rep()) + +@lu.transformation_with_aux +def _grab_out_rep(*args): + yield (yield args, {}) + + def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: assert isinstance(aval, core.ShapedArray) return aval @@ -429,7 +503,7 @@ def _check_shapedarray(aval: core.AbstractValue) -> core.ShapedArray: def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue ) -> core.AbstractValue: if isinstance(aval, core.ShapedArray): - return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ())) + return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape))) else: raise NotImplementedError # TODO(mattjj): add table with handlers @@ -437,7 +511,7 @@ def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue ) -> core.AbstractValue: if isinstance(aval, core.ShapedArray): - return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ())) + return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)), named_shape={k: v for k, v in aval.named_shape.items() if k not in mesh.shape}) @@ -447,7 +521,7 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue # Type-checking def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, - check_rep, auto): + check_rep, rewrite, auto): del auto # TODO(mattjj,parkers): check for v, x, in_name in zip(jaxpr.invars, in_atoms, in_names): if not core.typecompat(v.aval, _shard_aval(mesh, in_name, x.aval)): @@ -457,11 +531,11 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names, core.check_jaxpr(jaxpr) if check_rep: in_rep = map(partial(_in_names_to_rep, mesh), in_names) - out_rep = _output_rep(mesh, jaxpr, in_rep) + out_rep = _check_rep(mesh, jaxpr, in_rep) for rep, dst in zip(out_rep, out_names): if not _valid_repeats(mesh, rep, dst): - raise core.JaxprTypeError("shard_map can't prove output is sufficiently " - "replicated") + raise core.JaxprTypeError("shard_map can't prove output is " + "sufficiently replicated") out_avals_sharded = [x.aval for x in jaxpr.outvars] out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded) return out_avals, jaxpr.effects @@ -470,7 +544,7 @@ core.custom_typechecks[shard_map_p] = _shard_map_typecheck def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]: return set(mesh.axis_names) - {n for ns in names.values() for n in ns} -def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]], +def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]], ) -> Sequence[set[AxisName]]: env: dict[core.Var, set[AxisName]] = {} @@ -484,7 +558,7 @@ def _output_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[set[AxisName]], map(write, jaxpr.invars, in_rep) last_used = core.last_used(jaxpr) for e in jaxpr.eqns: - rule = _rep_rules.get(e.primitive, partial(_rep_rule, e.primitive)) + rule = 
_check_rules.get(e.primitive, partial(_rule_missing, e.primitive)) out_rep = rule(mesh, *map(read, e.invars), **e.params) if e.primitive.multiple_results: out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep @@ -500,8 +574,8 @@ def _valid_repeats(mesh: Mesh, rep: set[AxisName], dst: AxisNames) -> bool: # Lowering def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, - check_rep, auto): - del check_rep + check_rep, rewrite, auto): + del check_rep, rewrite in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, @@ -557,7 +631,7 @@ def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: # Eager evaluation def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, - check_rep, auto): + check_rep, rewrite, auto): if auto: raise NotImplementedError del prim, auto args = map(partial(_unmatch_spec, mesh), in_names, args) @@ -572,8 +646,10 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, del main, t, in_tracers, ans, out_tracers out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs_] _check_names(out_names_thunk(), out_avals) # pytype: disable=wrong-arg-types - if check_rep: _check_reps(mesh, out_names_thunk(), out_rep) - return map(partial(_match_spec, mesh), out_rep, out_names_thunk(), outs_) + if check_rep: + _check_reps(mesh, out_names_thunk(), out_rep) + return map(partial(_match_spec, mesh, check_rep), out_rep, out_names_thunk(), + outs_) core.EvalTrace.process_shard_map = _shard_map_impl def _names_to_pspec(names: AxisNames) -> PartitionSpec: @@ -587,7 +663,7 @@ def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType: def _unmatch(mesh, src_tup, x): src = _names_to_pspec(dict(src_tup)) dst = P(mesh.axis_names) - return shard_map(_add_singleton, mesh, (src,), dst)(x) + return shard_map(_add_singleton, mesh, (src,), dst, check_rep=False)(x) def _check_names(names: Sequence[AxisNames], avals: Sequence[core.ShapedArray] ) -> None: @@ -602,14 +678,21 @@ def _check_reps(mesh, names, reps): if any(f is not no_fail for f in fail): raise _RepError(fail) class _RepError(Exception): pass -def _match_spec(mesh: Mesh, rep: set[AxisName], dst: AxisNames, x: JaxType - ) -> JaxType: - with core.eval_context(): - return jax.jit(HashablePartial(_match, mesh, tuple(dst.items())))(x) +def _check_reps2(mesh, reps_dest, reps): + fail = [src if not dst.issubset(src) else no_fail + for dst, src in zip(reps_dest, reps)] + if any(f is not no_fail for f in fail): raise _RepError(fail) -def _match(mesh, dst_tup, x): +def _match_spec(mesh: Mesh, check_rep: bool, + rep: set[AxisName], dst: AxisNames, x: JaxType) -> JaxType: + fn = HashablePartial(_match, mesh, check_rep, tuple(dst.items())) + with core.eval_context(): + return jax.jit(fn)(x) + +def _match(mesh, check_rep, dst_tup, x): src = P(mesh.axis_names) dst = _names_to_pspec(dict(dst_tup)) + # TODO put back (?) needed for rep checking in eager? 
for now test rewrite return shard_map(_rem_singleton, mesh, (src,), dst, check_rep=False)(x) def _rem_singleton(x): return x.reshape(x.shape[1:]) @@ -640,7 +723,7 @@ class ShardMapTrace(core.Trace): f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) with core.eval_context(), jax.disable_jit(False): out_vals = jax.jit(f)(*in_vals) - rep_rule = _rep_rules.get(prim, partial(_rep_rule, prim)) + rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() if prim.multiple_results: out_rep = [out_rep] * len(out_vals) if type(out_rep) is set else out_rep @@ -758,92 +841,312 @@ def _device_put_eager_rule(mesh, x, *, src, device): f"shard_map-decorated functions, but got device {device}") eager_rules[dispatch.device_put_p] = _device_put_eager_rule -# Static replication checking +# New primitives for efficient transposition -def _rep_rule(prim: core.Primitive, mesh: Mesh, *in_rep: set[AxisName], - **params: Any) -> set[AxisName] | list[set[AxisName]]: - raise NotImplementedError( - f"No replication rule for {prim}. As a workaround, pass the " - "`check_rep=False` argument to `shard_map`. To get this fixed, open an " - "issue at https://github.com/google/jax/issues") +# psum2_p is like psum_p except has a different transpose, so mostly copied: +psum2_p = core.AxisPrimitive('psum2') +psum2_p.multiple_results = True +psum2_p.def_impl(lax_parallel.psum_p.impl) +psum2_p.def_effectful_abstract_eval(lax_parallel.psum_p.abstract_eval) +mlir.register_lowering(psum2_p, mlir._lowerings[lax_parallel.psum_p]) +batching.primitive_batchers[psum2_p] = partial(lax_parallel._reduction_batcher, psum2_p) +batching.axis_primitive_batchers[psum2_p] = \ + partial(lax_parallel._batched_reduction_collective, psum2_p, + lambda v, axis_size: axis_size * v) +core.axis_substitution_rules[psum2_p] = \ + partial(lax_parallel._subst_all_names_in_param, 'axes') +def _psum2_transpose_rule(cts, *args, axes, axis_index_groups): + del args + return pbroadcast_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups) +ad.deflinear2(psum2_p, _psum2_transpose_rule) -_rep_rules: dict[core.Primitive, Callable] = {} -register_rule = lambda prim: lambda rule: _rep_rules.setdefault(prim, rule) -register_standard = lambda prim: _rep_rules.setdefault(prim, _standard_rep_rule) +# pbroadcast_p is exactly the transpose of psum2_p +def pbroadcast(x, axis_name): + axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name + xs, treedef = tree_flatten(x) + ys = pbroadcast_p.bind(*xs, axes=axes, axis_index_groups=None) + return tree_unflatten(treedef, ys) +pbroadcast_p = core.AxisPrimitive('pbroadcast') +pbroadcast_p.multiple_results = True +pbroadcast_p.def_impl(lambda *args, axes, axis_index_groups: args) +pbroadcast_p.def_abstract_eval(lambda *args, axes, axis_index_groups: args) +mlir.register_lowering(pbroadcast_p, lambda ctx, *x, axes, axis_index_groups: x) +def _pbroadcast_batcher(vals_in, dims_in, *, axes, axis_index_groups): + if any(type(axis) is int for axis in axes): raise NotImplementedError + vals_out = pbroadcast_p.bind(*vals_in, axes=axes, + axis_index_groups=axis_index_groups) + return vals_out, dims_in +batching.primitive_batchers[pbroadcast_p] = _pbroadcast_batcher +def _pbroadcast_axis_batcher(size, name, trace_type, vals_in, dims_in, *, axes, + groups): + raise NotImplementedError # vmap with axis name involved in this primitive +batching.axis_primitive_batchers[pbroadcast_p] = _pbroadcast_axis_batcher 
+core.axis_substitution_rules[pbroadcast_p] = \ + partial(lax_parallel._subst_all_names_in_param, 'axes') +ad.deflinear2(pbroadcast_p, + lambda cts, *_, axes, axis_index_groups: + psum2_p.bind(*cts, axes=axes, axis_index_groups=axis_index_groups)) + +# Rewrite rules and static replication checking for efficient transposition + +_rewrite_rules: dict[core.Primitive, Callable] = {} +register_rewrite = lambda prim: lambda r: _rewrite_rules.setdefault(prim, r) +register_standard_rewrite = lambda prim: \ + _rewrite_rules.setdefault(prim, partial(_standard_rewrite_rule, prim)) +register_norewrite = lambda p: \ + _rewrite_rules.setdefault(p, partial(_no_rewrite, p, _check_rules[p])) + +_check_rules: dict[core.Primitive, Callable] = {} +register_check = lambda prim: lambda rule: _check_rules.setdefault(prim, rule) +register_standard_check = \ + lambda prim: _check_rules.setdefault(prim, partial(_standard_check, prim)) + +def _no_rewrite(prim, rule, mesh, in_rep, *args, **params): + out_vals = prim.bind(*args,**params) + out_rep = rule(mesh, *in_rep, **params) + if prim.multiple_results: + out_rep_ = out_rep if type(out_rep) is list else [out_rep] * len(out_vals) + else: + out_vals, out_rep_ = [out_vals], [out_rep] + return out_vals, out_rep_ + +def _standard_rewrite_rule(prim, mesh, in_rep, *args, **params): + # The standard rewrite inserts pbroadcasts but doesn't change the primitive. + out_rep_ = set.intersection(*in_rep) if in_rep else set(mesh.axis_names) + args_ = [pbroadcast(x, tuple(n for n in src if n not in out_rep_)) + if src - out_rep_ else x for x, src in zip(args, in_rep)] + out_vals_ = prim.bind(*args_, **params) + out_rep = [out_rep_] * len(out_vals_) if prim.multiple_results else [out_rep_] + out_vals = [out_vals_] if not prim.multiple_results else out_vals_ + return out_vals, out_rep + +def _standard_check(prim, mesh, *in_rep, **__): + # The standard check require args' and outputs' replications to be the same. + if in_rep and not in_rep[:-1] == in_rep[1:]: + raise Exception(f"Primitive {prim} requires argument replication types " + f"to match, but got {in_rep}. Please open an issue at " + "https://github.com/google/jax/issues") + return in_rep[0] if in_rep else set(mesh.axis_names) + +def register_standard_collective(prim): + register_check(prim)(partial(_standard_collective_check, prim)) + register_rewrite(prim)(partial(_standard_collective_rewrite, prim)) + +def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): + # The standard collective check is varying -> varying over axis_name. + del mesh, params + if axis_name in x_rep: + raise Exception(f"Collective {prim} must be applied to a device-varying " + f"replication type, but got {x_rep} for collective acting " + f"over axis name {axis_name}. Please open an issue at " + "https://github.com/google/jax/issues") + return x_rep + +def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): + # The standard collective rewrite may insert a pbroadcast on the input. 
+ if type(axis_name) is tuple: raise NotImplementedError # TODO + if params.get('axis_index_groups') is not None: raise NotImplementedError + x_rep, = in_rep + if axis_name in in_rep: + x = pbroadcast(x, (axis_name,)) + out_val = prim.bind(x, axis_name=axis_name, **params) + return [out_val], [x_rep - {axis_name}] -def _standard_rep_rule(_, *in_rep, **__): - return set.intersection(*in_rep) if in_rep else set() for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), windowed_reductions.__dict__.values(), fft.__dict__.values(), linalg.__dict__.values(), ops.__dict__.values(), - ad_util.__dict__.values(), prng.__dict__.values(), - custom_derivatives.__dict__.values()): - if isinstance(o, core.Primitive): register_standard(o) + ad_util.__dict__.values(), prng.__dict__.values()): + if isinstance(o, core.Primitive): + register_standard_check(o) + register_standard_rewrite(o) -register_standard(lax_parallel.ppermute_p) # doesn't change replication -@register_rule(lax_parallel.psum_p) -def _psum_rule(_, *in_rep, axes, axis_index_groups): +@register_check(lax_parallel.psum_p) +def _psum_check(_, *in_rep, axes, axis_index_groups): + assert False # should be rewritten away + +@register_rewrite(lax_parallel.psum_p) +def _psum_rewrite(_, in_rep, *args, axes, axis_index_groups): + # Replace the psum with psum2, insert pbroadcasts on input, replicated output. if axis_index_groups is not None: raise NotImplementedError axes = (axes,) if not isinstance(axes, tuple) else axes - return [r | set(axes) for r in in_rep] # introduces replication + out_rep = [r | set(axes) for r in in_rep] # TODO determinism (and elsewhere) + args_ = [pbroadcast(x, tuple(n for n in src if n not in dst)) + if src - dst else x for x, src, dst in zip(args, in_rep, out_rep)] + out_val = psum2_p.bind(*args_, axes=axes, axis_index_groups=axis_index_groups) + return out_val, out_rep -@register_rule(lax_parallel.all_gather_p) -def _all_gather_rule(_, in_rep, *, all_gather_dimension, axis_name, axis_size, - axis_index_groups, tiled): - if axis_index_groups is not None: raise NotImplementedError - if not tiled: raise NotImplementedError - axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name - return in_rep | set(axis_name) # introduces replication -@register_rule(lax_parallel.reduce_scatter_p) -def _reduce_scatter_rule(_, in_rep, *, scatter_dimension, axis_name, axis_size, - axis_index_groups, tiled): - if axis_index_groups is not None: raise NotImplementedError - if not tiled: raise NotImplementedError - return in_rep - {axis_name} # removes replication +@register_check(psum2_p) +def _psum2_check(_, *in_rep, axes, axis_index_groups): + assert type(axes) is tuple + if any(set(axes) & r for r in in_rep): + raise Exception("Collective psum must be applied to a device-varying " + f"replication type, but got {in_rep} for collective acting " + f"over axis name {axes}. 
Please open an issue at " + "https://github.com/google/jax/issues") + return [r | set(axes) for r in in_rep] +register_norewrite(psum2_p) -@register_rule(lax_parallel.all_to_all_p) -def _all_to_all_rule(_, in_rep, *, split_axis, concat_axis, axis_name, - axis_index_groups): - if axis_index_groups is not None: raise NotImplementedError - return in_rep - {axis_name} # removes replication -@register_rule(lax_parallel.axis_index_p) -def _axis_index_rule(mesh, *, axis_name): - axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name +@register_check(pbroadcast_p) +def _pbroadcast_check(_, *in_rep, axes, axis_index_groups): + assert type(axes) is tuple + if not all(set(axes) & r for r in in_rep): + raise Exception("Collective pbroadcast must be applied to a " + "non-device-varying " + f"replication type, but got {in_rep} for collective acting " + f"over axis name {axes}. Please open an issue at " + "https://github.com/google/jax/issues") + return [r - set(axes) for r in in_rep] +register_norewrite(pbroadcast_p) + + +register_standard_collective(lax_parallel.all_gather_p) +register_standard_collective(lax_parallel.all_to_all_p) +register_standard_collective(lax_parallel.ppermute_p) +register_standard_collective(lax_parallel.reduce_scatter_p) + + +@register_check(lax_parallel.axis_index_p) +def _axis_index_check(mesh, *, axis_name): + axis_name = (axis_name,) if not type(axis_name) is tuple else axis_name return set(mesh.shape) - set(axis_name) +register_norewrite(lax_parallel.axis_index_p) -@register_rule(pjit.pjit_p) -def _pjit_rule(mesh, *in_rep, jaxpr, **kwargs): - return _output_rep(mesh, jaxpr.jaxpr, in_rep) -@register_rule(debugging.debug_callback_p) +@register_rewrite(pjit.pjit_p) +def _pjit_rewrite(mesh, in_rep, *args, jaxpr, **kwargs): + jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep) + out_vals = pjit.pjit_p.bind(*args, jaxpr=jaxpr_, **kwargs) + return out_vals, out_rep + +@register_check(pjit.pjit_p) +def _pjit_check(mesh, *in_rep, jaxpr, **kwargs): + return _check_rep(mesh, jaxpr.jaxpr, in_rep) + + +@register_check(core.call_p) +def _core_call_check(mesh, *in_rep, call_jaxpr, **kwargs): + return _check_rep(mesh, call_jaxpr, in_rep) + + +@register_check(debugging.debug_callback_p) def _debug_callback_rule(mesh, *in_rep, **_): return [] +register_norewrite(debugging.debug_callback_p) -@register_rule(callback.pure_callback_p) -def _pure_callback_rule(mesh, *in_rep, result_avals, **_): + +@register_check(callback.pure_callback_p) +def _pure_callback_rule(mesh, *_, result_avals, **__): return [set()] * len(result_avals) +register_norewrite(callback.pure_callback_p) -@register_rule(dispatch.device_put_p) -def _device_put_rep_rule(mesh, x, *, src, device): + +@register_check(dispatch.device_put_p) +def _device_put_rule(mesh, x, **_): return x +register_norewrite(dispatch.device_put_p) -@register_rule(control_flow.loops.scan_p) -def _scan_rule(mesh, *in_rep, jaxpr, num_consts, num_carry, linear, length, - reverse, unroll): - const_rep, carry_rep, xs_rep = split_list(in_rep, [num_consts, num_carry]) + +@register_check(ad.custom_lin_p) +def _custom_lin_rule(mesh, *_, out_avals, **__): + return [set()] * len(out_avals) +register_norewrite(ad.custom_lin_p) + + +@register_check(control_flow.loops.scan_p) +def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): + _, carry_rep_in, _ = split_list(in_rep, [num_consts, num_carry]) + out_rep = _check_rep(mesh, jaxpr.jaxpr, in_rep) + carry_rep_out, _ = split_list(out_rep, [num_carry]) + if not 
carry_rep_in == carry_rep_out: + raise Exception("Scan carry input and output got mismatched replication " + f"types {carry_rep_in} and {carry_rep_out}. Please open an " + "issue at https://github.com/google/jax/issues") + return out_rep + +@register_rewrite(control_flow.loops.scan_p) +def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): + const_rep, carry_rep_in, xs_rep = split_list(in_rep, [num_consts, num_carry]) for _ in range(1 + num_carry): - out_rep = _output_rep(mesh, jaxpr.jaxpr, [*const_rep, *carry_rep, *xs_rep]) - if carry_rep == out_rep[:num_carry]: + in_rep_ = [*const_rep, *carry_rep_in, *xs_rep] + _, out_rep = _replication_rewrite_nomatch(mesh, jaxpr, in_rep_) + carry_rep_out, ys_rep = split_list(out_rep, [num_carry]) + carry_rep_out = map(op.and_, carry_rep_in, carry_rep_out) + if carry_rep_in == carry_rep_out: break else: - carry_rep = map(op.and_, carry_rep, out_rep[:num_carry]) + carry_rep_in = carry_rep_out else: assert False, 'Fixpoint not reached' - return out_rep + + args = [pbroadcast(x, tuple(n for n in src if n not in dst)) + if src - dst else x for x, src, dst in zip(args, in_rep, in_rep_)] + out_rep = [*carry_rep_out, *ys_rep] + jaxpr_ = _replication_rewrite_match(mesh, jaxpr, in_rep_, out_rep) + + out_vals = control_flow.loops.scan_p.bind( + *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) + return out_vals, out_rep + + +@register_rewrite(core.closed_call_p) +def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): + new_jaxpr, out_rep = _replication_rewrite_nomatch(mesh, call_jaxpr, in_rep) + out_vals = core.closed_call_p.bind(*args, jaxpr=new_jaxpr, **kwargs) + return out_vals, out_rep + +@register_check(core.closed_call_p) +def _closed_call_check(mesh, *in_rep, call_jaxpr, **kwargs): + return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) + + +@register_check(custom_derivatives.custom_jvp_call_p) +def _custom_jvp_call_check(mesh, *in_rep, call_jaxpr, jvp_jaxpr_thunk, + num_consts, symbolic_zeros): + return _check_rep(mesh, call_jaxpr.jaxpr, in_rep) + +@register_rewrite(custom_derivatives.custom_vjp_call_jaxpr_p) +def _custom_vjp_call_jaxpr_rewrite( + mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, + symbolic_zeros): + if symbolic_zeros: + msg = "Please open an issue at https://github.com/google/jax/issues !" 
+ raise NotImplementedError(msg) + + fun_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fun_jaxpr, in_rep) + _, in_rep_ = split_list(in_rep, [num_consts]) + out_rep2 = [] + + @pe._memoize + def fwd_jaxpr_thunk_(*zeros): + fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) + fwd_jaxpr_, out_rep = _replication_rewrite_nomatch(mesh, fwd_jaxpr, in_rep_) + out_rep2.append(out_rep) + return fwd_jaxpr_.jaxpr, fwd_jaxpr_.consts + + bwd_ = _rewrite_bwd(bwd, mesh, lambda: out_rep2[0], in_rep_) + + outs = custom_derivatives.custom_vjp_call_jaxpr_p.bind( + *args, fun_jaxpr=fun_jaxpr_, fwd_jaxpr_thunk=fwd_jaxpr_thunk_, bwd=bwd_, + num_consts=num_consts, out_trees=out_trees, symbolic_zeros=symbolic_zeros) + out_rep = out_rep2[0] if out_rep2 else out_rep + return outs, out_rep + +@register_check(custom_derivatives.custom_vjp_call_jaxpr_p) +def _custom_vjp_call_jaxpr_check(mesh, *in_rep, fun_jaxpr, **_): + return _check_rep(mesh, fun_jaxpr.jaxpr, in_rep) + + +del _check_rules[lax.tie_p] + +@register_check(lax.tie_p) +def _tie_check(mesh, x_rep, y_rep): + return x_rep +register_norewrite(lax.tie_p) + # Batching @@ -853,12 +1156,13 @@ def _shard_map_batch( in_names: tuple[AxisNames, ...], out_names_thunk: Callable[[], tuple[AxisNames, ...]], check_rep: bool, + rewrite: bool, auto: frozenset) -> Sequence[batching.BatchTracer]: in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in in_tracers) if all(bdim is batching.not_mapped for bdim in in_dims): return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names, out_names_thunk=out_names_thunk, check_rep=check_rep, - auto=auto) + rewrite=rewrite, auto=auto) if any(isinstance(d, batching.RaggedAxis) for d in in_dims): raise NotImplementedError fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims)) @@ -874,7 +1178,7 @@ def _shard_map_batch( new_params = dict(mesh=mesh, in_names=new_in_names, out_names_thunk=new_out_names_thunk, check_rep=check_rep, - auto=auto) + rewrite=rewrite, auto=auto) out_vals = prim.bind(fun, *in_vals, **new_params) make_tracer = partial(batching.BatchTracer, trace, source_info=source_info_util.current()) @@ -882,8 +1186,8 @@ def _shard_map_batch( batching.BatchTrace.process_shard_map = _shard_map_batch def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names, - out_names_thunk, check_rep, auto): - del mesh, in_names, out_names_thunk, check_rep, auto + out_names_thunk, check_rep, rewrite, auto): + del mesh, in_names, out_names_thunk, check_rep, rewrite, auto vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) for t in out_tracers) m = trace.main @@ -906,7 +1210,7 @@ def _batch_out_names(spmd_axis_name, dims, out_names): # Autodiff def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): + out_names_thunk, check_rep, rewrite, auto): primals, tangents = unzip2((t.primal, t.tangent) for t in tracers) which_nz = [ type(t) is not ad.Zero for t in tangents] tangents = [t if type(t) is not ad.Zero else None for t in tangents] @@ -921,7 +1225,7 @@ def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_names, return (*out_ax, *(ax for ax, nz in zip(out_ax, which_nz_out()) if nz)) params = dict(mesh=mesh, in_names=(*in_names, *tangent_in_names), out_names_thunk=new_out_names_thunk, check_rep=check_rep, - auto=auto) + rewrite=rewrite, auto=auto) f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind(f_jvp, *args, **params) primal_out, tangent_out = tree_unflatten(out_tree(), result) @@ -931,8 +1235,8 @@ def _shard_map_jvp(trace, 
shard_map_p, f, tracers, mesh, in_names, ad.JVPTrace.process_shard_map = _shard_map_jvp def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names, - out_names_thunk, check_rep, auto): - del mesh, in_names, out_names_thunk, check_rep, auto + out_names_thunk, check_rep, rewrite, auto): + del mesh, in_names, out_names_thunk, check_rep, rewrite, auto primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) out, treedef = tree_flatten((primals, tangents)) tangents_nz = [type(t) is not ad.Zero for t in tangents] @@ -946,7 +1250,7 @@ def _shard_map_jvp_post_process(trace, out_tracers, mesh, in_names, ad.JVPTrace.post_process_shard_map = _shard_map_jvp_post_process def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, - out_names_thunk, check_rep, auto): + out_names_thunk, check_rep, rewrite, auto): in_pvals = [t.pval for t in tracers] in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals) unk_in_names, known_in_names = pe.partition_list(in_knowns, in_names) @@ -966,7 +1270,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, known_params = dict(mesh=mesh, in_names=(*known_in_names,), out_names_thunk=known_out_names, check_rep=check_rep, - auto=auto) + rewrite=rewrite, auto=auto) out = shard_map_p.bind(f_known, *in_consts, **known_params) out_knowns, out_avals_sharded, jaxpr, env = aux() out_consts, res = split_list(out, [len(out) - len(jaxpr.constvars)]) @@ -980,7 +1284,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, unk_arg_tracers = [t for t in tracers if not t.is_known()] unk_params = dict(mesh=mesh, in_names=unk_in_names, out_names=unk_out_names, jaxpr=jaxpr, check_rep=False, - auto=auto) + rewrite=rewrite, auto=auto) out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) for a in out_avals] @@ -992,7 +1296,7 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names, pe.JaxprTrace.process_shard_map = _shard_map_partial_eval def _shard_map_partial_eval_post_process( - trace, tracers, mesh, in_names, out_names_thunk, check_rep, auto): + trace, tracers, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): del check_rep unk_tracers = [t for t in tracers if not t.is_known()] jaxpr, res, env = pe.tracers_to_jaxpr([], unk_tracers) @@ -1013,7 +1317,7 @@ def _shard_map_partial_eval_post_process( staged_in_names = ({0: (*mesh.axis_names,)},) * len(res) + ({},) * len(env) staged_params = dict(jaxpr=jaxpr_, mesh=mesh, in_names=staged_in_names, out_names=(*out_names_unknown,), check_rep=False, - auto=auto) + rewrite=rewrite, auto=auto) out_avals = map(partial(_unshard_aval, mesh), out_names_unknown, out_avals_) out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None) @@ -1053,10 +1357,11 @@ def _promote_scalar_residuals_jaxpr(jaxpr, res): return jaxpr, res def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, - check_rep, auto): + check_rep, rewrite, auto): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) + else x if rewrite + else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) @@ -1072,9 +1377,10 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, 
in_names, out_names, out = ad.backward_pass( jaxpr_unknown.jaxpr, (), False, (), (*res_reshaped, *undefs), out_cts ) - return [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else jax.lax.psum(x, tuple(_unmentioned(mesh, ns))) - for ns, x in zip(in_names, out)] + out = [ad.Zero(_unshard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero + else x if rewrite else jax.lax.psum(x, tuple(_unmentioned(mesh, ns))) + for ns, x in zip(in_names, out)] + return out fun_trans, nz_arg_cts = ad.nonzero_outputs(fun_trans) fun_trans_flat, out_tree = flatten_fun_nokwargs(fun_trans, in_tree) @@ -1088,7 +1394,8 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, out_flat = shard_map_p.bind( fun_trans_flat, *all_args, mesh=mesh, in_names=tuple(new_in_names), - out_names_thunk=new_out_names_thunk, check_rep=check_rep, auto=auto) + out_names_thunk=new_out_names_thunk, check_rep=check_rep, rewrite=rewrite, + auto=auto) return tree_unflatten(out_tree(), out_flat) ad.primitive_transposes[shard_map_p] = _shard_map_transpose @@ -1201,7 +1508,6 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn return used_inputs, new_eqn pe.dce_rules[shard_map_p] = _shard_map_dce - # Implementing pmap in terms of shard_map def pmap(f, axis_name=None, *, in_axes=0, out_axes=0, @@ -1277,3 +1583,148 @@ def _get_devices(p, backend): if jax.process_count() > 1: return devs[:p.global_axis_size] return devs[:p.local_axis_size] + + +### Rewrite! + +class RewriteTracer(core.Tracer): + rep: set[AxisName] + val: Val + + def __init__(self, trace, rep, val): + self._trace = trace + self.rep = rep + self.val = val + + @property + def aval(self) -> core.AbstractValue: + return core.get_aval(self.val) + + def full_lower(self) -> RewriteTracer: + return self + + def __str__(self) -> str: + return str(self.val) # TODO(mattjj): could show replication info here + +class RewriteTrace(core.Trace): + mesh: Mesh + dyna: int + + def __init__(self, *args, mesh, dyna): + super().__init__(*args) + self.mesh = mesh + self.dyna = dyna + + def pure(self, val) -> RewriteTracer: + return RewriteTracer(self, set(self.mesh.axis_names), val) + + def lift(self, tracer: core.Tracer) -> RewriteTracer: + return RewriteTracer(self, set(self.mesh.axis_names), tracer) + + def sublift(self, tracer: core.Tracer) -> RewriteTracer: + return RewriteTracer(self, tracer.rep, tracer.val) + + def process_primitive(self, prim, in_tracers, params): + rule = _rewrite_rules.get(prim, partial(_rule_missing, prim)) + in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) + with core.new_dynamic(self.dyna): + out_vals, out_reps = rule(self.mesh, in_reps, *in_vals, **params) + out_tracers = map(partial(RewriteTracer, self), out_reps, out_vals) + return out_tracers if prim.multiple_results else out_tracers[0] + + def process_call(self, call_primitive, f, in_tracers, params): + in_vals, in_reps = unzip2((t.val, t.rep) for t in in_tracers) + f, out_reps = _rewrite_subtrace(f, self.main, tuple(in_reps)) + with core.new_dynamic(self.dyna): + out_vals = call_primitive.bind(f, *in_vals, **params) + return map(partial(RewriteTracer, self), out_reps(), out_vals) + + def post_process_call(self, call_primitive, out_tracers, params): + assert False # unreachable + + def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): + if symbolic_zeros: + msg = "Please open an issue at https://github.com/google/jax/issues !" 
+ raise NotImplementedError(msg) + in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) + fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) + jvp, out_reps2 = _rewrite_subtrace(jvp, self.main, in_reps * 2) + with core.new_dynamic(self.dyna): + out_vals = prim.bind(fun, jvp, *in_vals, symbolic_zeros=symbolic_zeros) + fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) + if not fst: + assert out_reps == out_reps[:len(out_reps) // 2] * 2 + out_reps = out_reps[:len(out_reps) // 2] + return map(partial(RewriteTracer, self), out_reps, out_vals) + + def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): + assert False # unreachable + + def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, + symbolic_zeros): + if symbolic_zeros: + msg = "Please open an issue at https://github.com/google/jax/issues !" + raise NotImplementedError(msg) + in_vals, in_reps = unzip2((t.val, t.rep) for t in tracers) + fun, out_reps1 = _rewrite_subtrace(fun, self.main, in_reps) + fwd_in_reps = [r_ for r in in_reps for r_ in [r, set(self.mesh.axis_names)]] + fwd, out_reps2 = _rewrite_subtrace(fwd, self.main, fwd_in_reps) + bwd = _rewrite_bwd(bwd, self.mesh, out_reps2, in_reps) + with core.new_dynamic(self.dyna): + out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, + symbolic_zeros=symbolic_zeros) + fst, out_reps = lu.merge_linear_aux(out_reps1, out_reps2) + if not fst: + _, res_tree = out_trees() + _, out_reps = split_list(out_reps, [res_tree.num_leaves]) + return map(partial(RewriteTracer, self), out_reps, out_vals) + + def post_process_custom_vjp_call(self, out_tracers, _): + assert False # unreachable + + # TODO process_axis_index + +@lu.transformation +def _efficient_transpose_rewrite(mesh, in_names, out_names_thunk, *args): + in_reps = map(partial(_in_names_to_rep, mesh), in_names) + lvl = core.dynamic_level() + with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: + t = main.with_cur_sublevel() + in_tracers = map(partial(RewriteTracer, t), in_reps, args) + ans = yield in_tracers, {} + out_tracers = map(t.full_raise, ans) + out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) + del main, t, in_tracers, out_tracers, ans + out_rep_dst = [frozenset(_unmentioned(mesh, n)) for n in out_names_thunk()] + out_vals = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst + else x for x, src, dst in zip(out_vals, out_reps, out_rep_dst)] + yield out_vals + +@lu.transformation_with_aux +def _rewrite_subtrace(main, in_reps, *in_vals): + assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) + t = main.with_cur_sublevel() + in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) + with core.new_dynamic(main.level): + outs = yield in_tracers, {} + out_tracers = map(t.full_raise, outs) + out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) + yield out_vals, out_reps + +def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): + def new_bwd(*args): + lvl = core.dynamic_level() + with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: + bwd_, reps_thunk = _rewrite_subtrace(lu.wrap_init(bwd), main, in_reps()) + out = bwd_.call_wrapped(*args) + del main + return map(_match_replication, reps_thunk(), reps_dst, out) + return new_bwd + +def _match_replication(src, dst, x): + if dst - src: + x, = psum2_p.bind(x, axes=tuple(n for n in dst if n not in src), + axis_index_groups=None) + if src - dst: + x = pbroadcast(x, tuple(n for n in src if n not in dst)) + return x diff --git 
a/tests/shard_map_test.py b/tests/shard_map_test.py index 4fd177fcb..406892467 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -37,6 +37,7 @@ from jax._src import xla_bridge from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe +from jax._src import linear_util as lu from jax._src import tree_util import jax.numpy as jnp @@ -113,9 +114,12 @@ class ShardMapTest(jtu.JaxTestCase): mesh, a, _ = create_inputs(P('z', ('x', 'y')), P(None, None)) assert a.device_buffers[0].shape == (4, 2) + # NOTE(mattjj): to use out_specs=P(None, ('x', 'y')), we need to use + # all_gather_invariant primitive, which differs in its output replication + # type compared to all_gather. @jax.jit @partial(shard_map, mesh=mesh, - in_specs=(P('z', ('x', 'y')),), out_specs=P(None, ('x', 'y'))) + in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y'))) def fwd(a): return lax.all_gather(a, 'z', axis=0, tiled=True) @@ -559,8 +563,7 @@ class ShardMapTest(jtu.JaxTestCase): def test_partial_eval_custom_axis_env(self): mesh = Mesh(jax.devices(), ('i',)) - @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), - check_rep=False) # check_rep=False b/c no scan rep rule yet + @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(_): _, idx = jax.lax.scan(lambda _, __: (None, jax.lax.axis_index('i')), None, None, length=1) @@ -861,6 +864,274 @@ class ShardMapTest(jtu.JaxTestCase): self.assertIn("\"[('i',)]\"", mhlo_str) self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str) + def test_rewrite_process_call(self): + def f(x): + return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) + x = jnp.arange(4.) + y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call + self.assertAllClose(y, 2 * x * x, check_dtypes=True) + + def test_rewrite_post_process_call(self): + # We shouldn't hit post_process_call here because of RewriteTrace's dynamic + # behavior (i.e. no data dependence). + mesh = jtu.create_global_mesh((4,), ('x',)) + + @jax.jit + @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) + def f(x): + return core.call_p.bind(lu.wrap_init(lambda: [2. * x]))[0] * x + + x = jnp.arange(4.) + y = f(x) + self.assertAllClose(y, 2 * x * x, check_dtypes=True) + + def test_rewrite_process_custom_jvp_call(self): + @jax.custom_jvp + def foo(x): + return 2. * x + + @foo.defjvp + def foo_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + return foo(x), 2. * x_dot + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(lambda x: foo(x) * x, mesh, + in_specs=(P('x'),), out_specs=P('x')) + x = jnp.arange(4.) + + y = jax.jit(g)(x) + self.assertAllClose(y, 2 * x * x, check_dtypes=True) + + y2, y_dot = jax.jvp(jax.jit(g), (x,), (3 * x,)) + self.assertAllClose(y2, 2 * x * x, check_dtypes=True) + self.assertAllClose(y_dot, 2 * 2 * 3 * x * x, check_dtypes=True) + + def test_rewrite_process_custom_vjp_call(self): + @jax.custom_vjp + def foo(x): + return 2. * x + + def foo_fwd(x): + return foo(x), None + + def foo_bwd(_, y_bar): + return 2. * y_bar, + + foo.defvjp(foo_fwd, foo_bwd) + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(lambda x: foo(x) * x, mesh, + in_specs=(P('x'),), out_specs=P('x')) + + x = jnp.arange(4.) 
+ y = jax.jit(g)(x) + self.assertAllClose(y, 2 * x * x, check_dtypes=True) + + y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x) + self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True) + self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True) + + def test_rewrite_process_custom_vjp_call_match_more_replicated(self): + @jax.custom_vjp + def foo(x): + return 2. * x + + def foo_fwd(x): + return foo(x), None + + def foo_bwd(_, y_bar): + return jnp.ones_like(y_bar), # diff! more replicated than primal/tangent + + foo.defvjp(foo_fwd, foo_bwd) + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(lambda x: foo(x) * x, mesh, + in_specs=(P('x'),), out_specs=P('x')) + x = jnp.arange(4.) + + y = jax.jit(g)(x) + self.assertAllClose(y, 2 * x * x, check_dtypes=True) + + y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x) + self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True) + self.assertAllClose(x_bar, jnp.ones_like(x) + 2 * x, check_dtypes=True) + + def test_rewrite_process_custom_vjp_call_match_less_replicated(self): + @jax.custom_vjp + def foo(x, y): + del y + return 2. * x + + def foo_fwd(x, y): + return foo(x, y), y + + def foo_bwd(y, _): + return y, None # diff! x_bar less replicated than primal/tangent + + foo.defvjp(foo_fwd, foo_bwd) + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(lambda x, y: foo(x, y) * y, mesh, + in_specs=(P(), P('x')), out_specs=P('x')) + x = jnp.arange(4.) + y = jnp.arange(4 * 4.) + + z = jax.jit(g)(x, y) + self.assertAllClose(z, 2 * jnp.tile(x, (4,)) * y, check_dtypes=False) + + z_, x_bar = jax.value_and_grad(lambda x, y: jax.jit(g)(x, y).sum())(x, y) + self.assertAllClose(z.sum(), z_, check_dtypes=False) + self.assertAllClose(x_bar, jnp.arange(16).reshape(4, 4).sum(0), + check_dtypes=False) + + def test_rewrite_custom_vjp_call_jaxpr(self): + @jax.custom_vjp + def foo(x): + return 2. * x + + def foo_fwd(x): + return foo(x), None + + def foo_bwd(_, y_bar): + return 2. * y_bar, + + foo.defvjp(foo_fwd, foo_bwd) + + def foo_scan(x): + y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1) + return y + + mesh = jtu.create_global_mesh((4,), ('x',)) + g = shard_map(lambda x: foo_scan(x) * x, mesh, + in_specs=(P('x'),), out_specs=P('x')) + + x = jnp.arange(4.) + y = jax.jit(g)(x) + self.assertAllClose(y, 2 * x * x, check_dtypes=True) + + y_, x_bar = jax.value_and_grad(lambda x: jax.jit(g)(x).sum())(x) + self.assertAllClose(y_, (2 * x * x).sum(), check_dtypes=True) + self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True) + + def test_transpose_identity(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) + def f(x): + return x + + jaxpr = jax.make_jaxpr(jax.vjp(f, 1.)[1])(1.) + e, = jaxpr.jaxpr.eqns + self.assertEmpty(e.params['jaxpr'].eqns) + + jaxpr = jax.make_jaxpr(jax.vjp(jax.vjp(f, 1.)[1], 1.)[1])((1.,)) + e, = jaxpr.jaxpr.eqns + self.assertEmpty(e.params['jaxpr'].eqns) + + @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) + def g(x): + return jax.jit(lambda x: x)(x) + + jaxpr = jax.make_jaxpr(jax.vjp(g, 1.)[1])(1.) 
+ e, = jaxpr.jaxpr.eqns + e1, e2 = e.params['jaxpr'].eqns + self.assertEmpty(e1.outvars) + self.assertEmpty(e2.params['jaxpr'].eqns) + + def test_fanout_specs_transpose_to_psum(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) + def f(x): + return x + + jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(1.))[1])(jnp.arange(4.)) + e, = jaxpr.jaxpr.eqns + e2, = e.params['jaxpr'].eqns + self.assertEqual(str(e2.primitive), 'psum2') + self.assertEqual(e2.params['axes'], ('x',)) + + def test_fanin_psum_transposes_to_fanout(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P()) + def f(x): + return jax.lax.psum(x, 'x') + + jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.array([1.])) + e, = jaxpr.jaxpr.eqns + e1, = e.params['jaxpr'].eqns + self.assertEqual(str(e1.primitive), 'pbroadcast') + + def test_psum_with_implicit_fanout_self_transposes(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + return jax.lax.psum(x, 'x') + + jaxpr = jax.make_jaxpr(jax.vjp(f, jnp.arange(4.))[1])(jnp.arange(4.)) + e, = jaxpr.jaxpr.eqns + e1, e2 = e.params['jaxpr'].eqns + self.assertEqual(str(e1.primitive), 'psum2') + self.assertEqual(str(e2.primitive), 'pbroadcast') + + def test_rewrite_binops(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x')) + def f(x, y): + return x * y + + jaxpr = jax.make_jaxpr(f)(jnp.arange(1.), jnp.arange(4.)) + e, = jaxpr.jaxpr.eqns + e = e.params['jaxpr'].eqns[0] + self.assertEqual(e.primitive.name, 'pbroadcast') + self.assertEqual(e.params['axes'], ('x',)) + + def test_rewrite_scan(self): + mesh = jtu.create_global_mesh((4,), ('x',)) + + @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) + def f(x): + x, _ = jax.lax.scan(lambda x, _: (jax.lax.psum(x, 'x'), None), x, None, + length=2) + return x + + jaxpr = jax.make_jaxpr(f)(jnp.arange(4.)) + e, = jaxpr.jaxpr.eqns + e, = e.params['jaxpr'].eqns + e1, e2 = e.params['jaxpr'].eqns + self.assertEqual(e1.primitive.name, 'psum2') + self.assertEqual(e2.primitive.name, 'pbroadcast') + + def test_check_rep_false_grads(self): + # This test is redundant with the systematic tests below, but it serves as a + # direct regression test for a bug. 
+ mesh = jtu.create_global_mesh((4,), ('heads',)) + + def f(q, k, v): + + def body(q, k, v): + return q * k[None, :] + v[None, :] + + out = shard_map(body, mesh, check_rep=False, + in_specs=(q_spec, kv_spec, kv_spec,), + out_specs=q_spec)(q, k, v) + return out.sum() + + q_spec = P('heads', None) + kv_spec = P(None) + q = jax.device_put(jnp.arange(32.).reshape(4, 8), jax.sharding.NamedSharding(mesh, q_spec)) + k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) + v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) + + jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2) + class FunSpec(NamedTuple): name: str @@ -1126,12 +1397,15 @@ class ShardMapSystematicTest(jtu.JaxTestCase): self.assertAllClose(expected, out, check_dtypes=False) @parameterized.named_parameters( - sample(config.FLAGS.jax_num_generated_cases, sample_shmap)) + (name + f'_check_rep={check_rep}', *params, check_rep) + for (name, *params) in sample(config.FLAGS.jax_num_generated_cases, sample_shmap) + for check_rep in [True, False] + ) @jax.default_matmul_precision("float32") - def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _): + def test_grads(self, fun, mesh, jit, in_specs, out_specs, args, _, check_rep): mesh = self.make_mesh(mesh) args = map(jnp.array, args) - f = shard_map(fun, mesh, in_specs, out_specs) + f = shard_map(fun, mesh, in_specs, out_specs, check_rep=check_rep) if jit: f = jax.jit(f) jtu.check_grads(f, args, order=2, atol=1e-2, rtol=1e-2)
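
---

Illustrative sketch (not part of the patch): the behavior this change restores is that `check_rep=False` transposes must re-insert the explicit `psum`/division, while the `check_rep=True` path relies on the body rewrite (`psum2`/`pbroadcast`) instead. The snippet below is a minimal, hedged example of that user-visible contract; the `make_loss` helper and the 4-device mesh are assumptions for illustration, not code from the patch.

```python
# Sketch only: gradients must agree numerically for both check_rep settings;
# only the jaxprs differ (pbroadcast/psum2 vs. explicit psum/div in transpose).
# Assumes a host with at least 4 devices; adjust the mesh otherwise.
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices()[:4], ('i',))

def make_loss(check_rep):
  # psum produces an unmapped (replicated) scalar output, matching out_specs=P()
  @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P(),
           check_rep=check_rep)
  def loss(x):
    return jax.lax.psum((x ** 2).sum(), 'i')
  return loss

x = jnp.arange(8.)
g_new = jax.grad(make_loss(True))(x)   # efficient-transpose (rewrite) path
g_old = jax.grad(make_loss(False))(x)  # legacy path: explicit psum/div in transpose
assert jnp.allclose(g_new, g_old)      # same numbers either way
```

This mirrors `test_check_rep_false_grads` and the `check_rep` parametrization added to `test_grads`: the fix is purely about which collectives appear in the transposed program, not about the values it computes.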