2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2020-07-14 13:05:31 -07:00
|
|
|
import numpy as np
|
2020-06-04 15:27:48 -07:00
|
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-06-08 13:46:10 -07:00
|
|
|
import jax
|
2020-07-30 12:59:36 -07:00
|
|
|
from ..config import config
|
2018-11-17 18:03:33 -08:00
|
|
|
from .. import core
|
2020-11-18 21:17:02 -05:00
|
|
|
from ..core import ShapedArray, raise_to_shaped, Trace, Tracer
|
2019-07-27 15:46:14 -07:00
|
|
|
from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
|
2020-01-05 04:35:34 +01:00
|
|
|
from .. import linear_util as lu
|
2020-08-24 20:21:19 -04:00
|
|
|
from ..util import (unzip2, partial, safe_map, wrap_name, split_list,
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
canonicalize_axis, moveaxis, as_hashable_function)
|
2019-04-24 21:31:15 -07:00
|
|
|
from . import xla
|
2019-05-15 07:25:03 -07:00
|
|
|
from . import partial_eval as pe
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
# Deliberately shadow the builtin `map` with the length-checking variant from
# ..util so every use in this module asserts its argument lists are equal length.
map = safe_map
|
|
|
|
|
|
|
|
|
2020-08-14 18:22:04 +02:00
|
|
|
def batch(fun: lu.WrappedFun, in_vals, in_dims, out_dim_dests, axis_name):
  """Immediately execute a batched version of `fun` on `in_vals`.

  Convenience wrapper around `batch_fun`: builds the batching-transformed
  function and calls it right away, moving each output's batch dimension to
  the destination given by `out_dim_dests`.
  """
  transformed = batch_fun(fun, in_dims, out_dim_dests, axis_name=axis_name)
  return transformed.call_wrapped(*in_vals)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals, **params):
  """Run the wrapped function under a sublevel of an existing BatchTrace.

  Inputs with a batch dimension are boxed into BatchTracers; unmapped inputs
  (dim is None) pass through unchanged.  Yields the output values, with the
  tuple of output batch dims as the auxiliary result.
  """
  trace = main.with_cur_sublevel()
  in_tracers = []
  for val, dim in zip(in_vals, in_dims):
    in_tracers.append(val if dim is None else BatchTracer(trace, val, dim))
  outs = yield in_tracers, params
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  yield out_vals, out_dims
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
|
2020-08-14 18:22:04 +02:00
|
|
|
def batch_fun(fun : lu.WrappedFun, in_dims, out_dim_dests, axis_name,
              sum_match=False):
  """Transformation version of `batch`: returns the WrappedFun without calling it."""
  subtraced, out_dims_thunk = batch_subtrace(fun)
  return _batch_fun(subtraced, axis_name, sum_match, in_dims, out_dims_thunk,
                    out_dim_dests)
|
2020-01-15 15:00:38 -08:00
|
|
|
|
|
|
|
@lu.transformation
def _batch_fun(axis_name, sum_match, in_dims, out_dims_thunk, out_dim_dests,
               *in_vals, **params):
  # Core generator behind `batch_fun`: runs the subtraced function under a
  # fresh BatchTrace main trace, then moves every output's batch dim to its
  # requested destination.
  #
  # `in_dims` / `out_dim_dests` may be thunks because the pytree structure is
  # only known once the wrapped function has run; resolve them lazily.
  in_dims = in_dims() if callable(in_dims) else in_dims
  # Canonicalize negative integer axes against each input's rank.
  in_dims = [
      canonicalize_axis(dim, np.ndim(val)) if isinstance(dim, int) else dim
      for val, dim in zip(in_vals, in_dims)]
  # All mapped inputs must agree on the batched axis size; the single-element
  # destructuring raises if the sizes disagree (or if nothing is mapped).
  axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, axis_size, main):
      out_vals = yield (main, in_dims,) + in_vals, params
      # Drop our reference before leaving the context so the trace can be
      # torn down (and leaked-tracer checks stay accurate).
      del main
  out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
  out_dims = out_dims_thunk()
  for od, od_dest in zip(out_dims, out_dim_dests):
    if od is not None and not isinstance(od_dest, int) and not sum_match:
      # A mapped output needs an integer destination axis (out_axes=None is
      # not supported here).
      msg = f"vmap has mapped output but out_axes is {od_dest}"
      raise ValueError(msg)
  out_vals = map(partial(matchaxis, axis_size, sum_match=sum_match),
                 out_dims, out_dim_dests, out_vals)
  yield out_vals
|
|
|
|
|
|
|
|
def batch_fun2(fun : lu.WrappedFun, in_dims):
  """Like `batch_fun`, but with no out_dim_dests.

  Instead of moving outputs to requested destinations, returns the
  transformed function together with a thunk of the output batch dims.
  """
  subtraced, out_dims_thunk = batch_subtrace(fun)
  return _batch_fun2(subtraced, in_dims), out_dims_thunk
|
|
|
|
|
|
|
|
@lu.transformation
def _batch_fun2(in_dims, *in_vals, **params):
  """Generator behind `batch_fun2`: run under a fresh, anonymous BatchTrace."""
  with core.new_main(BatchTrace, axis_name=None) as main_trace:
    out_vals = yield (main_trace, in_dims, *in_vals), params
    del main_trace  # release our reference before leaving the trace context
  yield out_vals
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### tracer
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
# TODO(mattjj): use a special sentinel type rather than None
NotMapped = type(None)  # the *type* accepted in checks for "no batch dimension"
not_mapped = None       # the sentinel *value* marking an unbatched dimension
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
class BatchTracer(Tracer):
  """Tracer pairing a batched value with the index of its batch dimension.

  `batch_dim` is either an int axis index into `val` or `not_mapped` (None),
  meaning the value carries no batch dimension at this trace level.
  """
  __slots__ = ['val', 'batch_dim']

  def __init__(self, trace, val, batch_dim: Optional[int]):
    # Cheap invariant check, skipped when internal checks are disabled.
    assert core.skip_checks or type(batch_dim) in (int, NotMapped)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    """Abstract value of this tracer, with the batch dimension removed."""
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped or aval is core.abstract_unit:
      return aval
    if type(aval) is ShapedArray:
      assert 0 <= self.batch_dim < aval.ndim
      shape = tuple(np.delete(aval.shape, self.batch_dim))
      return ShapedArray(shape, aval.dtype)
    raise TypeError(aval)

  def full_lower(self):
    # An unbatched tracer adds no information, so unwrap to the plain value.
    return core.full_lower(self.val) if self.batch_dim is not_mapped else self
|
|
|
|
|
|
|
|
class BatchTrace(Trace):
|
2020-10-26 10:11:13 +00:00
|
|
|
  def __init__(self, *args, axis_name):
    # `axis_name` names the mapped axis this trace batches over (None for an
    # anonymous vmap); everything else is standard Trace construction.
    super().__init__(*args)
    self.axis_name = axis_name
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
  def pure(self, val):
    # Constants entering the trace carry no batch dimension.
    return BatchTracer(self, val, not_mapped)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
  def lift(self, val):
    # Values lifted from an outer trace are unbatched at this level.
    return BatchTracer(self, val, not_mapped)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
  def sublift(self, val):
    # Re-wrap a tracer from a lower sublevel, preserving its batch dim.
    return BatchTracer(self, val.val, val.batch_dim)
|
|
|
|
|
|
|
|
  def process_primitive(self, primitive, tracers, params):
    """Apply `primitive` to batched tracers via its batching rule.

    Falls back to a plain bind when no input is batched; collectives over an
    axis owned by this trace's main are dispatched to `collective_rules`.
    """
    vals_in, dims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims_in):
      # No batched inputs: nothing for this trace to do.
      return primitive.bind(*vals_in, **params)
    if (primitive in collective_rules and
        _main_trace_for_axis_names(self.main, params['axis_name'])):
      # NOTE(review): the membership check uses params['axis_name'] but the
      # frame lookup below uses self.axis_name — presumably these coincide
      # here; confirm against the collective_rules implementations.
      frame = core.axis_frame(self.axis_name)
      val_out, dim_out = collective_rules[primitive](frame, vals_in, dims_in, **params)
    else:
      # Ordinary primitive: look up and apply its registered batching rule.
      batched_primitive = get_primitive_batcher(primitive, self.axis_name)
      val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
    if primitive.multiple_results:
      return map(partial(BatchTracer, self), val_out, dim_out)
    else:
      return BatchTracer(self, val_out, dim_out)
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
|
2019-07-27 15:46:14 -07:00
|
|
|
assert call_primitive.multiple_results
|
2020-01-15 15:00:38 -08:00
|
|
|
params = dict(params, name=wrap_name(params.get('name', f.__name__), 'vmap'))
|
2020-04-21 18:12:02 -07:00
|
|
|
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
|
|
|
if all(bdim is not_mapped for bdim in dims):
|
|
|
|
return call_primitive.bind(f, *vals, **params)
|
2018-11-17 18:03:33 -08:00
|
|
|
else:
|
2020-08-30 01:16:51 -07:00
|
|
|
f, dims_out = batch_subtrace(f, self.main, dims)
|
2020-04-21 18:12:02 -07:00
|
|
|
vals_out = call_primitive.bind(f, *vals, **params)
|
|
|
|
return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
|
|
|
|
|
|
|
|
def post_process_call(self, call_primitive, out_tracers, params):
|
|
|
|
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
2020-08-30 12:38:14 +03:00
|
|
|
main = self.main
|
2020-04-21 18:12:02 -07:00
|
|
|
def todo(vals):
|
2020-10-26 10:11:13 +00:00
|
|
|
trace = main.with_cur_sublevel()
|
2020-04-21 18:12:02 -07:00
|
|
|
return map(partial(BatchTracer, trace), vals, dims)
|
|
|
|
return vals, todo
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-09 20:41:01 +01:00
|
|
|
  def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
    """Batch a map primitive (e.g. pmap) by threading the batch dim through it.

    Shifts the map's `in_axes`/`out_axes_thunk` around the vmap batch
    dimension so both the mapped axis and the batched axis survive the trip
    through the map primitive.
    """
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(dim is not_mapped for dim in dims):
      return map_primitive.bind(f, *vals, **params)
    else:
      # All batched inputs must agree on the batch-axis size.
      assert len({x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped}) == 1
      # The logic for the dimension math below is as follows:
      # ╔═════════════╦════════════════════════════════════════╦═══════════╗
      # ║ d / in_axis ║ None                                   ║ int       ║
      # ╠═════════════╬════════════════════════════════════════╩═══════════╣
      # ║ None        ║ No extra axis, so in_axis unaffected                ║
      # ╠═════════════╬════════════════════════════════════════╦═══════════╣
      # ║ int         ║ Not mapped, so batching dim unaffected ║ See below ║
      # ╚═════════════╩════════════════════════════════════════╩═══════════╝
      # When both d and in_axis are defined then:
      # - If `d <= in_axis`, we have to move the `in_axis` one dimension further;
      # - If `d > in_axis`, we have to decrement `d` (as `in_axis` will get removed).
      def both_mapped(in_out_axis, d):
        return in_out_axis is not None and d is not not_mapped
      new_in_axes = tuple(
        in_axis + 1 if both_mapped(in_axis, d) and d <= in_axis else in_axis
        for d, in_axis in zip(dims, params['in_axes']))
      new_dims = tuple(
        d - 1 if both_mapped(in_axis, d) and in_axis < d else d
        for d, in_axis in zip(dims, params['in_axes']))
      f, dims_out = batch_subtrace(f, self.main, new_dims)
      out_axes_thunk = params['out_axes_thunk']
      # The new thunk must hash/compare like the old one (keyed on it) so the
      # global compilation cache of the map primitive still works.
      @as_hashable_function(key=out_axes_thunk)
      def new_out_axes_thunk():
        return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
                     for out_axis, d in zip(out_axes_thunk(), dims_out()))
      new_params = dict(params, in_axes=new_in_axes, out_axes_thunk=new_out_axes_thunk)
      vals_out = map_primitive.bind(f, *vals, **new_params)
      # NOTE: `dims_out` the WrappedFun aux thunk is consumed here and the
      # name is rebound to a lazy generator of per-output batch dims; it is
      # only safe to evaluate after bind has run the wrapped function.
      dims_out = (d + 1 if both_mapped(out_axis, d) and out_axis <= d else d
                  for d, out_axis in zip(dims_out(), out_axes_thunk()))
      return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out)]
|
2019-07-27 15:46:14 -07:00
|
|
|
|
2020-04-21 18:12:02 -07:00
|
|
|
def post_process_map(self, call_primitive, out_tracers, params):
|
2019-07-27 15:46:14 -07:00
|
|
|
vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
2020-08-30 12:38:14 +03:00
|
|
|
main = self.main
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
def both_mapped(in_out_axis, d):
|
|
|
|
return in_out_axis is not None and d is not not_mapped
|
2020-04-21 18:12:02 -07:00
|
|
|
def todo(vals):
|
2020-10-26 10:11:13 +00:00
|
|
|
trace = main.with_cur_sublevel()
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
return [BatchTracer(trace, v, d + 1 if both_mapped(out_axis, d) and out_axis <= d else d)
|
|
|
|
for v, d, out_axis in zip(vals, dims, params['out_axes_thunk']())]
|
|
|
|
if call_primitive.map_primitive:
|
|
|
|
def out_axes_transform(out_axes):
|
|
|
|
return tuple(out_axis + 1 if both_mapped(out_axis, d) and d < out_axis else out_axis
|
|
|
|
for out_axis, d in zip(out_axes, dims))
|
|
|
|
todo = (todo, out_axes_transform)
|
2019-07-27 15:46:14 -07:00
|
|
|
return vals, todo
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_jvp_call(self, prim, fun, jvp, tracers):
  # Batching rule for a custom_jvp_call: run both the primal function and the
  # user-supplied JVP rule under a batching subtrace tied to this trace's main.
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
  fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
  jvp, out_dims2 = batch_custom_jvp_subtrace(jvp, self.main, in_dims)
  out_vals = prim.bind(fun, jvp, *in_vals)
  # Only one of the two aux stores was populated, depending on which of
  # `fun`/`jvp` actually ran under bind; merge picks whichever has a value.
  fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
  if not fst:
    # The JVP path returns primals followed by tangents; their batch dims
    # must agree pairwise, so keep just the primal half.
    assert out_dims == out_dims[:len(out_dims) // 2] * 2
    out_dims = out_dims[:len(out_dims) // 2]
  return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
|
|
|
|
|
2020-10-16 00:21:04 -07:00
|
|
|
def post_process_custom_jvp_call(self, out_tracers, params):
  """Lift custom-JVP outputs out of this trace, deferring re-wrapping.

  Returns the unwrapped values together with a `todo` callback that, when
  later invoked with values, re-wraps them as BatchTracers at this trace's
  main (at the then-current sublevel).
  """
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  main = self.main
  def todo(new_vals):
    trace = main.with_cur_sublevel()
    # Re-wrap each value with the batch dim recorded at unwrap time.
    return map(partial(BatchTracer, trace), new_vals, out_dims)
  return out_vals, todo
|
|
|
|
|
2020-03-28 14:15:46 -07:00
|
|
|
def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees):
  # Batching rule for custom_vjp_call: `fun` and `fwd` are batched under
  # subtraces; `bwd` is batched as a standalone function whose input dims are
  # fwd's output dims and whose output dims must match the primal input dims.
  in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers)
  fun, out_dims1 = batch_subtrace(fun, self.main, in_dims)
  fwd, out_dims2 = batch_subtrace(fwd, self.main, in_dims)
  # TODO(mattjj,apaszke): support collectives in custom_vjp?
  # sum_match=True lets bwd sum over a batch dim when the cotangent must be
  # matched to an unmapped input; the axis name is a placeholder.
  bwd = batch_fun(bwd, out_dims2, in_dims,
                  axis_name='__unused_axis_name', sum_match=True)
  out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees)
  # Only one of out_dims1/out_dims2 was populated, depending on whether `fun`
  # or `fwd` ran under bind.
  fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2)
  if not fst:
    # fwd's outputs are (residuals + primal outputs); keep the batch dims of
    # the trailing primal outputs. NOTE(review): the modulo makes this slice
    # keep everything when len(out_vals) == len(out_dims) — confirm intended
    # over the simpler `out_dims[-len(out_vals):]`.
    out_dims = out_dims[-len(out_vals) % len(out_dims):]
  return [BatchTracer(self, v, d) for v, d in zip(out_vals, out_dims)]
|
|
|
|
|
2020-11-24 09:58:44 -08:00
|
|
|
def _main_trace_for_axis_names(main_trace: core.MainTrace,
                               axis_name: Union[core.AxisName, Tuple[core.AxisName, ...]]
                               ) -> bool:
  """Check whether `main_trace` is the trace associated with any `axis_name`.

  Axis names alone can't identify a trace because names may shadow one
  another, so the main trace object itself is used as the tag.
  """
  names = axis_name if isinstance(axis_name, (list, tuple)) else (axis_name,)
  for name in names:
    if core.axis_frame(name).main_trace is main_trace:
      return True
  return False
|
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### primitives
|
|
|
|
|
2020-03-18 17:06:05 -04:00
|
|
|
# A batching rule maps batched arguments (plus batching metadata) to outputs
# paired with the output batch dimension(s).
BatchingRule = Callable[..., Tuple[Any, Union[int, Tuple[int, ...]]]]
# Final-style batching rules, keyed by primitive.
primitive_batchers : Dict[core.Primitive, BatchingRule] = {}
# Initial-style batching rules; these are additionally passed the axis name.
initial_style_batchers : Dict[core.Primitive, Any] = {}
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2020-10-26 10:11:13 +00:00
|
|
|
def get_primitive_batcher(p, axis_name):
  """Look up the batching rule for primitive `p`.

  Initial-style rules take precedence and are closed over `axis_name`;
  otherwise the final-style rule is returned as-is.

  Raises:
    NotImplementedError: if no rule is registered for `p`.
  """
  try:
    initial_rule = initial_style_batchers[p]
  except KeyError:
    pass
  else:
    return partial(initial_rule, axis_name=axis_name)
  try:
    return primitive_batchers[p]
  except KeyError as err:
    msg = "Batching rule for '{}' not implemented"
    raise NotImplementedError(msg.format(p)) from err
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defvectorized(prim):
  # Register `prim` as elementwise: its batching rule just forwards the
  # inputs' shared batch dimension unchanged.
  primitive_batchers[prim] = partial(vectorized_batcher, prim)
|
|
|
|
|
|
|
|
def vectorized_batcher(prim, batched_args, batch_dims, **params):
  """Batching rule for elementwise primitives.

  All inputs must share the same batch dimension; the primitive is applied
  unchanged and the common batch dimension is passed through.
  """
  bdim = batch_dims[0]
  assert all(bd == bdim for bd in batch_dims[1:]), batch_dims
  out = prim.bind(*batched_args, **params)
  return out, bdim
|
|
|
|
|
|
|
|
def defbroadcasting(prim):
  # Register `prim` as handling its own broadcasting; batching aligns batch
  # dims at the front and relies on the primitive's broadcasting semantics.
  primitive_batchers[prim] = partial(broadcast_batcher, prim)
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def broadcast_batcher(prim, args, dims, **params):
  """Process a primitive with built-in broadcasting.

  Args:
    args: the possibly-batched arguments
    dims: list or tuple of the same length as `args`, where each
      entry indicates the batching state of the corresponding entry to `args`:
      either an int indicating the batch dimension, or else `not_mapped`
      indicating no batching.

  Returns:
    A pair of the primitive's output(s) and the output batch dimension(s)
    (a tuple of dims when the primitive has multiple results).
  """
  # Collect (shape, batch_dim) pairs for non-scalar args only; scalars
  # broadcast freely and so impose no constraint.
  shapes = {(x.shape, d) for x, d in zip(args, dims) if np.ndim(x)}
  if len(shapes) == 1:
    # if there's only agreeing batch dims and scalars, just call the primitive
    d = next(d for d in dims if d is not not_mapped)
    out = prim.bind(*args, **params)
    return (out, (d,) * len(out)) if prim.multiple_results else (out, d)
  else:
    # Mixed shapes/dims: move every batch dim to the front (broadcasting
    # unbatched args), then pad ranks so the primitive can broadcast.
    size, = {shape[d] for shape, d in shapes if d is not not_mapped}
    args = [bdim_at_front(x, d, size) for x, d in zip(args, dims)]
    ndim = max(np.ndim(x) for x in args)  # special-case scalar broadcasting
    args = [_handle_scalar_broadcasting(ndim, x, d) for x, d in zip(args, dims)]
    out = prim.bind(*args, **params)
    return (out, (0,) * len(out)) if prim.multiple_results else (out, 0)
|
2019-07-27 15:46:14 -07:00
|
|
|
|
|
|
|
def _handle_scalar_broadcasting(nd, x, d):
  """Right-pad `x` with singleton axes up to rank `nd`.

  Unbatched values (`d is not_mapped`) and values already at rank `nd` are
  returned unchanged; otherwise trailing length-1 axes are appended so the
  value broadcasts against the other (batched) operands.
  """
  if d is not_mapped:
    return x
  rank = np.ndim(x)
  if rank == nd:
    return x
  return x.reshape(x.shape + (1,) * (nd - rank))
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
def defreducer(prim):
  # Register the reduction batching rule for `prim` (expects an `axes` param).
  primitive_batchers[prim] = partial(reducer_batcher, prim)
|
|
|
|
|
2019-01-10 15:35:15 -08:00
|
|
|
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  """Batching rule for reduction primitives.

  Shifts each reduced axis past the batch dimension, computes where the
  batch dimension lands among the surviving axes, and rebinds the primitive.
  """
  operand, = batched_args
  bdim, = batch_dims
  # Axes at or after the batch dim move one slot to the right.
  axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
  # Position of bdim among the non-reduced output axes.
  remaining = [ax for ax in range(operand.ndim) if ax not in axes]
  bdim_out = remaining.index(bdim)
  if 'input_shape' in params:
    # Some reducers carry the operand shape as a param; refresh it to the
    # batched shape.
    params = dict(params, input_shape=operand.shape)
  return prim.bind(operand, axes=axes, **params), bdim_out
|
2018-11-17 18:03:33 -08:00
|
|
|
|
2019-07-02 12:18:47 -04:00
|
|
|
# sets up primitive batchers for ad_util and xla primitives
|
2019-01-07 08:34:48 -08:00
|
|
|
|
2018-11-17 18:03:33 -08:00
|
|
|
def add_batched(batched_args, batch_dims):
  # Batching rule for adding two possibly-batched jaxvals.
  bdx, bdy = batch_dims
  x, y = batched_args
  if bdx == bdy or core.get_aval(x) == core.abstract_unit:
    # Batch dims already agree (or the values are units): add directly.
    return add_jaxvals(x, y), bdx
  elif bdx is not_mapped:
    # Only y is batched: broadcast x along y's batch dim, then add.
    x = broadcast(x, y.shape[bdy], bdy)
    return add_jaxvals(x, y), bdy
  elif bdy is not_mapped:
    # Only x is batched: broadcast y along x's batch dim, then add.
    y = broadcast(y, x.shape[bdx], bdx)
    return add_jaxvals(x, y), bdx
  else:
    # Both batched but along different dims: align x's batch dim with y's.
    x = moveaxis(x, bdx, bdy)
    return add_jaxvals(x, y), bdy
primitive_batchers[add_jaxvals_p] = add_batched
|
|
|
|
|
2019-01-07 08:34:48 -08:00
|
|
|
def zeros_like_batched(batched_args, batch_dims):
  """Batching rule for zeros_like: zeros inherit the input's batch dim."""
  (val,), (bdim,) = batched_args, batch_dims
  zeros = zeros_like_jaxval(val)
  return zeros, bdim
primitive_batchers[zeros_like_p] = zeros_like_batched
|
|
|
|
|
2019-07-02 12:18:47 -04:00
|
|
|
defvectorized(xla.device_put_p)  # device_put is elementwise w.r.t. batching
|
2018-11-17 18:03:33 -08:00
|
|
|
|
|
|
|
### util
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def broadcast(x, sz, axis):
  """Insert a new axis of size `sz` at position `axis` by broadcasting `x`."""
  if core.get_aval(x) is core.abstract_unit:
    # Units carry no data; nothing to broadcast.
    return core.unit
  result_shape = list(np.shape(x))
  result_shape.insert(axis, sz)
  # Every original dim maps through unchanged; only `axis` is new.
  bcast_dims = tuple(np.delete(np.arange(len(result_shape)), axis))
  return jax.lax.broadcast_in_dim(x, result_shape, bcast_dims)
|
2019-07-27 15:46:14 -07:00
|
|
|
|
2020-03-29 20:51:51 -07:00
|
|
|
def matchaxis(sz, src, dst, x, sum_match=False):
  # Move, create, or (optionally) sum out x's batch dimension so that it sits
  # at `dst`. `sz` is the batch size used when a new axis must be broadcast in.
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  if src == dst:
    return x
  elif type(src) == type(dst) == int:
    return moveaxis(x, src, dst)
  elif src is not_mapped and dst is not not_mapped:
    # x wasn't batched: broadcast a fresh axis of size sz at dst (dst may be
    # negative, hence the canonicalization against the post-insert rank).
    return broadcast(
      x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
  elif dst is None and sum_match:
    # Caller opted in to summing out the batch axis when the target is
    # unmapped.
    return x.sum(src)
  else:
    raise ValueError((src, dst))
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def bdim_at_front(x, bdim, size):
  """Return `x` with its batch dimension moved (or broadcast) to axis 0."""
  if core.get_aval(x) is core.abstract_unit:
    return core.unit
  # Unbatched values get a fresh leading axis of length `size`; batched
  # values have their existing batch axis moved to the front.
  return (broadcast(x, size, 0) if bdim is not_mapped
          else moveaxis(x, bdim, 0))
|
2019-05-15 07:25:03 -07:00
|
|
|
|
|
|
|
|
2019-07-27 15:46:14 -07:00
|
|
|
def _promote_aval_rank(sz, aval):
  """Prepend a leading axis of length `sz` to `aval`'s shape.

  The unit abstract value has no shape and passes through unchanged.
  """
  if aval is core.abstract_unit:
    return core.abstract_unit
  return ShapedArray((sz, *aval.shape), aval.dtype)
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2020-10-26 10:11:13 +00:00
|
|
|
def batch_jaxpr(closed_jaxpr, size, batched, instantiate, axis_name):
  # Batch a ClosedJaxpr: retrace its function under a BatchTrace with input
  # avals given a leading `size` axis wherever `batched` is True.
  f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr))
  f, batched_out = batched_traceable(f, size, batched, instantiate, axis_name)
  avals_in = [_promote_aval_rank(size, a) if b else a
              for a, b in zip(closed_jaxpr.in_avals, batched)]
  jaxpr_out, _, consts = pe.trace_to_jaxpr_dynamic(f, avals_in)
  # batched_out is a thunk populated by the trace above; call it only now.
  return core.ClosedJaxpr(jaxpr_out, consts), batched_out()
|
2019-05-15 07:25:03 -07:00
|
|
|
|
2020-01-05 04:35:34 +01:00
|
|
|
@lu.transformation_with_aux
def batched_traceable(size, batched, instantiate, axis_name, *vals):
  # Generator-style transformation: wraps `vals` in BatchTracers (batch dim 0
  # for batched inputs), runs the wrapped function, then unwraps the outputs.
  in_dims = [0 if b else None for b in batched]
  with core.new_main(BatchTrace, axis_name=axis_name) as main:
    with core.extend_axis_env(axis_name, size, main):
      trace = main.with_cur_sublevel()
      ans = yield map(partial(BatchTracer, trace), vals, in_dims), {}
      out_tracers = map(trace.full_raise, ans)
      out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
      # Drop tracer references before leaving the trace context.
      del main, out_tracers
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_vals)
  # Normalize each output so its batch dim (if any) is at axis 0; broadcast
  # unbatched outputs only when instantiation is requested for them.
  out_vals = [moveaxis(x, d, 0) if d is not not_mapped and d != 0
              else broadcast(x, size, 0) if d is not_mapped and inst else x
              for x, d, inst in zip(out_vals, out_dims, instantiate)]
  out_batched = [d is not not_mapped or inst
                 for d, inst in zip(out_dims, instantiate)]
  # Second yield: (values, aux); the aux feeds the batched_out thunk.
  yield out_vals, out_batched
|
2020-03-28 14:15:46 -07:00
|
|
|
|
|
|
|
|
|
|
|
@lu.transformation_with_aux
def batch_custom_jvp_subtrace(main, in_dims, *in_vals):
  # Batches a custom-JVP rule. `in_vals` holds primals followed by tangents,
  # so `in_dims` (given for the primals) is doubled to cover both halves.
  size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped}
  trace = main.with_cur_sublevel()
  in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
                for val, dim in zip(in_vals, in_dims * 2)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
  out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
  # Each primal/tangent pair must end up with the same batch dim: merge the
  # two, then move both halves onto the merged dims.
  out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
  out_primals = map(partial(matchaxis, size), out_primal_bds, out_dims, out_primals)
  out_tangents = map(partial(matchaxis, size), out_tangent_bds, out_dims, out_tangents)
  yield out_primals + out_tangents, out_dims * 2
|
|
|
|
|
|
|
|
def _merge_bdims(x, y):
|
|
|
|
if x == y:
|
|
|
|
return x
|
|
|
|
elif x is not_mapped:
|
|
|
|
return y
|
|
|
|
elif y is not_mapped:
|
|
|
|
return x
|
|
|
|
else:
|
|
|
|
return x # arbitrary
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
|
2020-09-15 08:06:46 -07:00
|
|
|
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  # Swap in the pre-omnistaging implementation of batch_jaxpr, which traces
  # via partial evaluation (pe.trace_to_jaxpr) instead of dynamic tracing.
  global batch_jaxpr

  def batch_jaxpr(jaxpr, size, batched, instantiate, axis_name):
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f, batched_out = batched_traceable(f, size, batched, instantiate, axis_name)
    avals_in = [_promote_aval_rank(size, a) if b else a
                for a, b in zip(jaxpr.in_avals, batched)]
    in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    jaxpr_out, pvals_out, consts_out = pe.trace_to_jaxpr(f, in_pvals, instantiate=True)
    avals_out, _ = unzip2(pvals_out)
    # batched_out is a thunk populated during tracing; call it only now.
    return core.ClosedJaxpr(jaxpr_out, consts_out), batched_out()
|
2020-08-14 18:22:04 +02:00
|
|
|
|
|
|
|
|
|
|
|
# Batching rules for collective primitives, keyed by primitive.
collective_rules: Dict[core.Primitive, Callable] = {}
|