rocm_jax/jax/interpreters/batching.py

# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from typing import Any, Callable, Dict, Optional, Tuple, Union
import jax
from ..config import config
from .. import core
from ..core import ShapedArray, raise_to_shaped, Trace, Tracer
from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
from .. import linear_util as lu
from ..util import (unzip2, partial, safe_map, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function)
from . import xla
from . import partial_eval as pe
map = safe_map

def batch(fun: lu.WrappedFun, in_vals, in_dims, out_dim_dests, axis_name):
  # executes a batched version of `fun` following out_dim_dests
  batched_fun = batch_fun(fun, in_dims, out_dim_dests, axis_name=axis_name)
  return batched_fun.call_wrapped(*in_vals)

@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals, **params):
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, params
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_dims
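
# Illustrative sketch (not part of this module's API): `batch_subtrace` above is
# a generator with two `yield`s. The first hands the rewritten arguments to the
# wrapped function; the second returns the post-processed outputs together with
# auxiliary data (there, the output batch dims). The driver below is a
# simplified, hypothetical stand-in for what `lu.transformation_with_aux`
# arranges. It returns the auxiliary value directly, whereas the callers in this
# file retrieve it later through a thunk. The names `_example_run_two_yield` and
# `_example_double_inputs` are invented here.
def _example_run_two_yield(gen_fn, f, *args, **kwargs):
  gen = gen_fn(*args, **kwargs)
  inner_args, inner_kwargs = next(gen)   # first yield: inputs for `f`
  outs = f(*inner_args, **inner_kwargs)  # call the wrapped function
  return gen.send(outs)                  # second yield: (outputs, aux)

def _example_double_inputs(*args):
  outs = yield [2 * a for a in args], {}
  yield outs, len(outs)  # aux data: here simply the number of outputs

# Usage: _example_run_two_yield(_example_double_inputs, lambda x, y: [x + y], 1, 2)
# calls the lambda on (2, 4) and returns ([6], 1).
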
def batch_fun(fun: lu.WrappedFun, in_dims, out_dim_dests, axis_name,
              sum_match=False):
  # transformation version of batch, which doesn't call the function
  fun, out_dims = batch_subtrace(fun)
  return _batch_fun(fun, axis_name, sum_match, in_dims, out_dims, out_dim_dests)

@lu.transformation
def _batch_fun(axis_name, sum_match, in_dims, out_dims_thunk, out_dim_dests,
               *in_vals, **params):
  in_dims = in_dims() if callable(in_dims) else in_dims
  in_dims = [
    canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim
    for val, dim in zip(in_vals, in_dims)]
  axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims,) + in_vals, params
    del main
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, od_dest in zip(out_dims, out_dim_dests):
    if od is not None and not isinstance(od_dest, int) and not sum_match:
      msg = f"vmap has mapped output but out_axes is {od_dest}"
      raise ValueError(msg)
  out_vals = map(partial(matchaxis, axis_size, sum_match=sum_match),
                 out_dims, out_dim_dests, out_vals)
  yield out_vals
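
# Illustrative sketch (not part of this module's API): how `_batch_fun` above
# derives the mapped axis size from the inputs. `_example_axis_size` is a
# hypothetical helper that works on plain shape tuples, and `d % len(shape)` is
# a simplified stand-in for `canonicalize_axis`, which is assumed to also
# reject out-of-range axes.
def _example_axis_size(in_shapes, in_dims):
  # Normalize negative axes against each rank; None means "not mapped".
  dims = [d % len(s) if isinstance(d, int) else d
          for s, d in zip(in_shapes, in_dims)]
  # Unpacking a one-element set raises ValueError unless all mapped sizes agree.
  axis_size, = {s[d] for s, d in zip(in_shapes, dims) if d is not None}
  return axis_size

# Usage: _example_axis_size([(3, 4), (5, 3), ()], [0, -1, None]) returns 3,
# while _example_axis_size([(3,), (4,)], [0, 0]) raises ValueError.
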

def batch_fun2(fun: lu.WrappedFun, in_dims):
  # like `batch_fun` but returns output batch dims (so no out_dim_dests)
  fun, out_dims = batch_subtrace(fun)
  return _batch_fun2(fun, in_dims), out_dims

@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
  with core.new_main(BatchTrace, axis_name=None) as main:
    out_vals = yield (main, in_dims,) + in_vals, params
    del main
  yield out_vals

### tracer
# TODO(mattjj): use a special sentinel type rather than None
NotMapped = type(None)
not_mapped = None
class BatchTracer(Tracer):
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    assert core.skip_checks or type(batch_dim) in (int, NotMapped)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped:
      return aval
    else:
      if aval is core.abstract_unit:
        return aval
      elif type(aval) is ShapedArray:
        assert 0 <= self.batch_dim < aval.ndim
        new_shape = tuple(np.delete(aval.shape, self.batch_dim))
        return ShapedArray(new_shape, aval.dtype)
      else:
        raise TypeError(aval)

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return core.full_lower(self.val)
    else:
      return self
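
# Illustrative sketch (not part of this module's API): the shape arithmetic
# behind `BatchTracer.aval` above. Code running under the batch trace sees the
# per-example shape, i.e. the stored value's shape with the batch dimension
# removed. `_example_unbatched_shape` is a hypothetical helper that works on
# plain shape tuples instead of `ShapedArray`s.
def _example_unbatched_shape(shape, batch_dim):
  if batch_dim is None:  # not mapped: the shape is seen unchanged
    return tuple(shape)
  assert 0 <= batch_dim < len(shape)
  return tuple(s for i, s in enumerate(shape) if i != batch_dim)

# Usage: _example_unbatched_shape((8, 3, 5), 0) and
# _example_unbatched_shape((3, 8, 5), 1) both return (3, 5).
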
class BatchTrace(Trace):
  def __init__(self, *args, axis_name):
    super().__init__(*args)
    self.axis_name = axis_name

  def pure(self, val):
    return BatchTracer(self, val, not_mapped)

  def lift(self, val):
    return BatchTracer(self, val, not_mapped)

  def sublift(self, val):
    return BatchTracer(self, val.val, val.batch_dim)

  def process_primitive(self, primitive, tracers, params):
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      return primitive.bind(*vals_in, **params)
    if (primitive in collective_rules and
        _main_trace_for_axis_names(self.main, params['axis_name'])):
      frame = core.axis_frame(self.axis_name)
      val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
    else:
      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    if primitive.multiple_results:
      return map(partial(BatchTracer, self), val_out, dim_out)
    else:
      return BatchTracer(self, val_out, dim_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    else:
      f, dims_out = batch_subtrace(f, self.main, dims)
      vals_out = call_primitive.bind(f, *vals, **params)
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows
      # (rows are the batch dim `d`, columns are `in_axis`):
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected               ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
      # E.g. d=0, in_axis=0 gives in_axis=1 with d unchanged, while
      # d=2, in_axis=1 gives d=1 with in_axis unchanged.
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # The thunk may only be called after `f` has run, since `dims_out()` is
      # populated by `batch_subtrace` during that call.
      @as_hashable_function(key=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                  for d, out_axis in zip(dims_out(), out_axes_thunk()))
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]

def post_process_map(self, call_primitive, out_tracers, params):
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
main = self.main
Add support for non-zero (but still not-None) out_axes in pmap Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`), but its semantics would match the specification of `out_axes=0` (i.e. all outputs should be stacked along the first axis). This patch makes it possible to specify non-zero values for out_axes, but more importantly it lays down the groundwork for `xmap` which will have to use some extremely similar (if not the same) code paths. One thing to note is that when I started this implementation I was also planning to add support for `out_axes=None`, which would allow us to stop using the `unbroadcast` hack, and most of the code is written with that in mind. Unfortunately it turned out that the correct implementation of the transpose rule for maps that do allow unmapped outputs would require me to pretty much simulate what avals-with-names is supposed to achieve. Technically replicated outputs should work today, for as long as the user does not do reverse-mode AD of `pmap`. But I decided that it's better to just disable them altogether until we can get the full and correct behavior. * Implementation details * This patch is significantly more involved than the one that implemented general `in_axes` support. That previous one at least had the foundation of `mapped_invars` which already behaved pretty similarly to general `in_axes`. From a quick glance one might think that `out_axes` should behave similarly to `in_axes`, but it turns out that this is not the case, at least not if we're interested in keeping those primitives final-style. ** Thunking ** The biggest difficulty with handling `out_axes` in final style primitives is that we want to treat them as a prefix of the output pytree, but we don't know the structure of the output pytree until the user function is evaluated! And the user function is not evaluated until we've applied all transforms and reached the impl rule! 
The solution to this problem is "straightforward": instead of putting `out_axes` as a primitive parameter, we bundle an `out_axes_thunk` which can only be called successfully after the wrapped function has been executed. The thunk returns a list of flat `out_axes`, expanded to the output pytree. However, the thunking presents us with two problems: *** Transformations *** Each transformation that modifies the number of outputs needs to ensure that the thunk is updated to reflect the new values. To make things worse a lot of the transforms can learn the number of added outputs _only after the wrapped function is evaluated_, which leads to the following "time travel" pattern that can be found in most `Trace`s: ```py @lu.transformation_with_aux def compute_output_statistic(*args, **kwargs): outputs = yield args, kwargs yield outputs, compute_statistic(outputs) wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun) def new_out_axes_thunk(): old_out_axes = params['out_axes_thunk']() return compute_new_out_axes(old_out_axes(), output_statistic()) primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk)) ``` The reason why we have to structure the code this way is that we can only specify a new `out_axes_thunk` before we bind the primitive, but we need the outputs of bind to know how to update the `out_axes_thunk`. To make things worse, the implementation of `bind` is allowed to make a call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_. This means that we cannot compute the output statistic in the implementation of the transformation, but we have to use an extra `lu.transformation_with_aux` for that (this populates the statistic store immediately after `wrapped_fun` is evaluated). The `compute_statistic` function depends on the transform in question. E.g. in the JVP trace it counts the number of non-zero tangent results. The situation is of course further complicated when we take `post_process_map` into account. 
The new `process_env_traces` now always sets up this funny time travel trampoline just in case it ends up being necessary, and `post_process_map` is now expected to return `(outputs, (todo, out_axes_transform))` instead of just `(outputs, todo)`. *** Compilation cache *** Because the `out_axes_thunk`s are now arguments to a _global_ compilation cache (in the form of `lu.cache` decorator on `parallel_callable`), we have to ensure that they implement `hash` and `==`. This is what forces us to add some slightly weird helpers such as `_hashable_function` and `_ignore_elem_list`. The code that uses those makes an assumption that the output pytree depends deterministically on the identity of the wrapped function, which I think is in line with general JAX assumptions. Otherwise the cache would depend on the identity of the thunk, which changes with every function invocation. Relaxing the global constraint on the cache (e.g. allowing each `pmap(f)` instance to have a separate cache) would make this easier too. * Why final style? * Now, making the primitives initial-style would remove the necessity for thunking, because we could have obtained the output pytree right when the function is wrapped. I assumed there is a good argument for making `pmap` pretend that it's a final-style primitive, but I'm not sure why that is? I hope it's something better than just avoiding a single jaxpr tracing.
2020-11-09 17:23:16 +00:00
def both_mapped(in_out_axis, d):
return in_out_axis is not None and d is not not_mapped
def todo(vals):
trace = main.with_cur_sublevel()
      return [BatchTracer(trace, v, d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
              for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes, dims))
      todo = (todo, out_axes_transform)
    return vals, todo

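# Illustrative sketch, not part of this module: the two index shifts used when
# post-processing a map (`d + 1 if out_axis <= d` for the batch dimension and
# `out_axis + 1 if d < out_axis` for the map output axis) can be checked on
# labelled axes. `adjust_axes` and `stacked_layout` are hypothetical helpers,
# and the sketch assumes both axes are mapped and given as non-negative ints.

```python
def adjust_axes(d, out_axis):
    """Mirror the two index shifts for a batch dim d and a map output axis.

    d indexes the underlying value (which carries the batch dimension);
    out_axis is expressed without the batch dimension.
    """
    new_d = d + 1 if out_axis <= d else d
    new_out_axis = out_axis + 1 if d < out_axis else out_axis
    return new_d, new_out_axis


def stacked_layout(logical_axes, d, out_axis):
    """Return the axis labels of the stacked result, checking both shifts."""
    val = list(logical_axes)
    val.insert(d, "batch")              # value carrying the batch dimension
    new_d, new_out_axis = adjust_axes(d, out_axis)
    result = list(val)
    result.insert(new_out_axis, "map")  # axis along which the map stacks outputs
    # The batch dimension is found at its shifted position ...
    assert result[new_d] == "batch"
    # ... and, ignoring the batch dimension, the map axis sits at out_axis.
    unbatched = [a for a in result if a != "batch"]
    assert unbatched.index("map") == out_axis
    return result


print(stacked_layout(["x", "y"], d=0, out_axis=1))  # ['batch', 'x', 'map', 'y']
```

# The two conditions are complementary (`out_axis <= d` versus `d < out_axis`),
# so exactly one of the two indices moves for any pair of mapped axes.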
  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
    out_vals = prim.bind(fun, jvp, *in_vals)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      assert out_dims == out_dims[:len(out_dims) // 2] * 2
      out_dims = out_dims[:len(out_dims) // 2]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]

  def post_process_custom_jvp_call(self, out_tracers, params):
    vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = main.with_cur_sublevel()
      return map(partial(BatchTracer, trace), vals, dims)
    return vals, todo

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
    in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
    fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
    fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
    # TODO(mattjj,apaszke): support collectives in custom_vjp?
    bwd = batch_fun(bwd, out_dims2, in_dims,
                    axis_name='__unused_axis_name', sum_match=True)
    out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
    fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
    if not fst:
      out_dims = out_dims[-len(out_vals) % len(out_dims):]
    return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
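
# Illustrative sketch, not part of this module: the slice start
# `-len(out_vals) % len(out_dims)` evaluates to `m - n` when 0 < n < m and to 0
# when n == m, so the slice keeps the trailing n entries. `trailing_out_dims` is
# a hypothetical helper. That any extra leading entries belong to values other
# than the primal outputs (for example residuals produced by `fwd`) is an
# assumption here, not something this file states.

```python
def trailing_out_dims(out_dims, num_out_vals):
    """Keep the last num_out_vals entries of out_dims."""
    start = -num_out_vals % len(out_dims)
    return out_dims[start:]


# Four batch dims recorded, two outputs: keep the last two.
print(trailing_out_dims([0, None, 1, 0], 2))  # [1, 0]
# Lengths already match: the start index is 0 and the list is unchanged.
print(trailing_out_dims([1, 0], 2))           # [1, 0]
```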

def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
                               ) -> bool:
  # This function exists to identify whether a main trace corresponds to any of
  # the axis names used by a primitive. Axis names alone aren't enough because
  # axis names can shadow, so we use the main trace as a tag.
  if not isinstance(axis_name, (list, tuple)):
    axis_name = (axis_name,)
  return any(main_trace is core.axis_frame(n).main_trace for n in axis_name)

### primitives

BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
initial_style_batchers : Dict[core.Primitive, Any] = {}

def get_primitive_batcher(p, axis_name):
  if p in initial_style_batchers:
    return partial(initial_style_batchers[p], axis_name=axis_name)
  try:
    return primitive_batchers[p]
  except KeyError as err:
    msg = "Batching rule for '{}' not implemented"
    raise NotImplementedError(msg.format(p)) from err

def defvectorized(prim):
  primitive_batchers[prim] = partial(vectorized_batcher, prim)

def vectorized_batcher(prim, batched_args, batch_dims, **params):
  assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
  return prim.bind(*batched_args, **params), batch_dims[0]

def defbroadcasting(prim):
  primitive_batchers[prim] = partial(broadcast_batcher, prim)

def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.
  """
  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) == 1:
    # if there's only agreeing batch dims and scalars, just call the primitive
    d = next(d for d in dims if d is not not_mapped)
    out = prim.bind(*args, **params)
    return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
  else:
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
    out = prim.bind(*args, **params)
    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)

def _handle_scalar_broadcasting(nd, x, d):
  if d is not_mapped or nd == np.ndim(x):
    return x
  else:
    return x.reshape(x.shape + (1,) * (nd - np.ndim(x)))

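# Illustrative sketch, not part of this module: the rank promotion performed by
# `_handle_scalar_broadcasting`, written on plain shape tuples. `promote_shape`
# is a hypothetical helper, and the sketch assumes the batch dimension has
# already been moved to the front, so trailing size-1 axes keep it there when
# the operand is broadcast against a higher-rank one.

```python
def promote_shape(nd, shape, mapped):
    """Pad a mapped operand's shape with trailing 1s up to rank nd."""
    if not mapped or nd == len(shape):
        return tuple(shape)
    return tuple(shape) + (1,) * (nd - len(shape))


# A batched scalar of shape (8,) next to an operand of rank 3:
print(promote_shape(3, (8,), mapped=True))       # (8, 1, 1)
# Unmapped operands and operands already of rank nd are left alone.
print(promote_shape(3, (), mapped=False))        # ()
print(promote_shape(3, (8, 3, 4), mapped=True))  # (8, 3, 4)
```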
def defreducer(prim):
  primitive_batchers[prim] = partial(reducer_batcher, prim)

def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  operand, = batched_args
  bdim, = batch_dims
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
  if 'input_shape' in params:
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out

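# Illustrative sketch, not part of this module: the index arithmetic of
# `reducer_batcher` in pure Python. `batched_reduction_axes` is a hypothetical
# helper; `ndim` is the rank of the batched operand, and `axes` are assumed to
# be non-negative reduction axes of the unbatched operand.

```python
def batched_reduction_axes(axes, bdim, ndim):
    """Shift reduction axes past the batch dim and locate it in the output."""
    # Axes at or after the batch dimension move up by one.
    shifted = tuple(a if a < bdim else a + 1 for a in axes)
    # Axes surviving the reduction, in order; the batch dim is one of them.
    kept = [i for i in range(ndim) if i not in shifted]
    return shifted, kept.index(bdim)


print(batched_reduction_axes((0,), bdim=0, ndim=3))    # ((1,), 0)
print(batched_reduction_axes((0,), bdim=2, ndim=3))    # ((0,), 1)
print(batched_reduction_axes((0, 1), bdim=1, ndim=3))  # ((0, 2), 0)
```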
# sets up primitive batchers for ad_util and xla primitives

def add_batched(batched_args, batch_dims):
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    return add_jaxvals(x, y), bdx
  elif bdx is not_mapped:
    x = broadcast(x, y.shape[bdy], bdy)
    return add_jaxvals(x, y), bdy
  elif bdy is not_mapped:
    y = broadcast(y, x.shape[bdx], bdx)
    return add_jaxvals(x, y), bdx
  else:
    x = moveaxis(x, bdx, bdy)
    return add_jaxvals(x, y), bdy
primitive_batchers[add_jaxvals_p] = add_batched

def zeros_like_batched(batched_args, batch_dims):
  val, = batched_args
  bdim, = batch_dims
  return zeros_like_jaxval(val), bdim
primitive_batchers[zeros_like_p] = zeros_like_batched

defvectorized(xla.device_put_p)

### util

def broadcast(x, sz, axis):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  shape = list(np.shape(x))
  shape.insert(axis, sz)
  broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
  return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)

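# Illustrative sketch, not part of this module: the shape bookkeeping done by
# `broadcast` before it calls `jax.lax.broadcast_in_dim`, in pure Python.
# `broadcast_plan` is a hypothetical helper, and `axis` is assumed to be a
# non-negative index into the output shape.

```python
def broadcast_plan(shape, sz, axis):
    """Return the output shape and the output positions of the input axes."""
    out_shape = list(shape)
    out_shape.insert(axis, sz)  # new axis of size sz
    # Every output axis except the new one corresponds to an input axis.
    broadcast_dims = tuple(i for i in range(len(out_shape)) if i != axis)
    return tuple(out_shape), broadcast_dims


print(broadcast_plan((3, 4), sz=5, axis=1))  # ((3, 5, 4), (0, 2))
print(broadcast_plan((), sz=5, axis=0))      # ((5,), ())
```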
def matchaxis(sz, src, dst, x, sum_match=False):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  elif type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  elif src is not_mapped and dst is not not_mapped:
    return broadcast(
      x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  elif dst is None and sum_match:
    return x.sum(src)
  else:
    raise ValueError((src, dst))

def bdim_at_front(x, bdim, size):
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if bdim is not_mapped:
    return broadcast(x, size, 0)
  else:
    return moveaxis(x, bdim, 0)

def _promote_aval_rank(sz, aval):
  if aval is core.abstract_unit:
    return core.abstract_unit
  else:
    return ShapedArray((sz,) + aval.shape, aval.dtype)

def batch_jaxpr(closed_jaxpr, size, batched, instantiate, axis_name):
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, batched_out = batched_traceable(f, size, batched, instantiate, axis_name)
  avals_in = [_promote_aval_rank(size, a) if b else a
              for a, b in zip(closed_jaxpr.in_avals, batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  return core.ClosedJaxpr(jaxpr_out, consts), batched_out()

@lu.transformation_with_aux
def batched_traceable(size, batched, instantiate, axis_name, *vals):
  in_dims = [0 if b else None for b in batched]
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, size, main):
      trace = main.with_cur_sublevel()
      ans = yield map(partial(BatchTracer, trace), vals, in_dims), {}
      out_tracers = map(trace.full_raise, ans)
      out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
    del main, out_tracers
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)
  out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
              else broadcast(x, size, 0) if d is not_mapped and inst else x
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  yield out_vals, out_batched

@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims * 2)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  out_primals = map(partial(matchaxis, size), out_primal_bds, out_dims, out_primals)
  out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2

def _merge_bdims(x, y):
  if x == y:
    return x
  elif x is not_mapped:
    return y
  elif y is not_mapped:
    return x
  else:
    return x  # arbitrary

@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global batch_jaxpr

  def batch_jaxpr(jaxpr, size, batched, instantiate, axis_name):
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f, batched_out = batched_traceable(f, size, batched, instantiate, axis_name)
    avals_in = [_promote_aval_rank(size, a) if b else a
                for a, b in zip(jaxpr.in_avals, batched)]
    in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
    avals_out, _ = unzip2(pvals_out)
    return core.ClosedJaxpr(jaxpr_out, consts_out), batched_out()


collective_rules: Dict[core.Primitive, Callable] = {}