2019-01-28 11:13:34 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
2020-04-23 16:01:05 -07:00
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
2019-02-23 20:34:14 -08:00
|
|
|
from functools import partial
|
2020-06-15 09:10:40 -07:00
|
|
|
import itertools as it
|
2020-06-23 09:29:58 -04:00
|
|
|
import gc
|
2019-12-17 14:44:03 -08:00
|
|
|
import os
|
2019-12-17 16:22:55 -08:00
|
|
|
from random import shuffle
|
2020-06-29 16:22:05 -07:00
|
|
|
from typing import Optional, cast
|
2021-10-04 17:54:18 -07:00
|
|
|
import unittest
|
2021-03-29 13:58:04 -07:00
|
|
|
from unittest import SkipTest
|
2020-06-12 16:10:45 -07:00
|
|
|
import warnings
|
2020-06-23 09:29:58 -04:00
|
|
|
import weakref
|
2019-02-23 20:34:14 -08:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
import numpy as np
|
2019-01-28 11:13:34 -08:00
|
|
|
from absl.testing import absltest
|
|
|
|
from absl.testing import parameterized
|
|
|
|
|
2019-11-06 08:36:53 -08:00
|
|
|
import jax
|
2020-05-05 14:59:16 -04:00
|
|
|
import jax.numpy as jnp
|
2021-09-24 07:02:08 -07:00
|
|
|
from jax._src import test_util as jtu
|
2020-01-29 03:04:59 +00:00
|
|
|
from jax import tree_util
|
2019-02-23 20:34:14 -08:00
|
|
|
from jax import lax
|
2021-02-08 20:24:19 -08:00
|
|
|
from jax._src.lax import parallel
|
2021-08-04 14:46:21 -07:00
|
|
|
from jax._src import api as src_api
|
2019-09-11 06:01:32 -07:00
|
|
|
from jax import random
|
2020-11-18 21:17:02 -05:00
|
|
|
from jax.core import ShapedArray
|
2021-01-26 19:38:40 -08:00
|
|
|
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
|
|
|
|
linearize, device_put)
|
2021-11-22 08:22:10 -08:00
|
|
|
from jax._src import device_array
|
2021-09-23 06:33:25 -07:00
|
|
|
import jax._src.lib
|
|
|
|
from jax._src.lib import xla_bridge
|
2021-01-11 14:20:32 -08:00
|
|
|
from jax._src.util import prod, safe_map
|
2019-03-19 16:54:55 -07:00
|
|
|
from jax.interpreters import pxla
|
2019-07-06 10:00:08 -07:00
|
|
|
from jax.interpreters import xla
|
2019-01-28 11:13:34 -08:00
|
|
|
|
|
|
|
from jax.config import config
|
|
|
|
config.parse_flags_with_absl()

# XLA_FLAGS as they were before setUpModule ran; tearDownModule restores them.
prev_xla_flags = None

# Each inner list holds shapes that broadcast against one another.
compatible_shapes = [
    [(3,)],
    [(3, 4), (3, 1), (1, 4)],
    [(2, 3, 4), (2, 1, 4)],
]
|
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation, I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately, it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically, replicated outputs should work
today, as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)

wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the need for
thunking, because we could obtain the output pytree right when
the function is wrapped. I assume there is a good reason for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
it is. I hope it's something better than just avoiding a single jaxpr
trace.
2020-11-09 17:23:16 +00:00
|
|
|
def all_bdims(*shapes, pmap):
|
2021-03-29 13:58:04 -07:00
|
|
|
bdims = (it.chain([cast(Optional[int], None)], range(len(shape) + 1))
|
|
|
|
for shape in shapes)
|
2020-06-29 16:22:05 -07:00
|
|
|
return (t for t in it.product(*bdims) if not all(e is None for e in t))
|
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation, I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately, it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically, replicated outputs should work
today, as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
  outputs = yield args, kwargs
  yield outputs, compute_statistic(outputs)

wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
  old_out_axes = params['out_axes_thunk']()
  return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the need for
thunking, because we could obtain the output pytree right when
the function is wrapped. I assume there is a good reason for making
`pmap` pretend that it's a final-style primitive, but I'm not sure what
it is. I hope it's something better than just avoiding a single jaxpr
trace.
2020-11-09 17:23:16 +00:00
|
|
|
def out_bdims(shape, pmap):
  """Yield each non-None batch-dimension choice for a single `shape`."""
  # all_bdims is called with exactly one shape, so every combination it
  # yields is a 1-tuple that can be unpacked directly.
  candidates = all_bdims(shape, pmap=pmap)
  return (bdim for (bdim,) in candidates if bdim is not None)
|
|
|
|
|
|
|
|
|
2020-06-29 16:22:05 -07:00
|
|
|
def add_bdim(bdim_size, bdim, shape):
  """Return `shape` as a tuple with an axis of size `bdim_size` at `bdim`.

  When `bdim` is None the shape is returned unchanged (converted to a
  tuple). Insertion follows `list.insert` semantics, so negative and
  out-of-range positions are accepted.
  """
  dims = list(shape)
  if bdim is None:
    return tuple(dims)
  dims.insert(bdim, bdim_size)
  return tuple(dims)
|
|
|
|
|
|
|
|
def slicer(x, bdim):
  """Return a function mapping an index `i` to the `i`-th slice of `x`.

  The slice is taken along axis `bdim`, and that axis is dropped from the
  result. When `bdim` is None the returned function ignores its argument
  and always returns `x` itself.
  """
  if bdim is not None:
    def take(i):
      return lax.index_in_dim(x, i, bdim, keepdims=False)
    return take
  return lambda _: x
|
|
|
|
|
|
|
|
def args_slicer(args, bdims):
  """Return a function that slices every argument at the same index `i`.

  `args[k]` is paired with `bdims[k]` and sliced as `slicer` would slice it.
  The returned function produces a list with one slice per argument.
  NOTE(review): `safe_map` presumably rejects mismatched lengths of `args`
  and `bdims`; confirm against jax._src.util.
  """
  per_arg_slicers = safe_map(slicer, args, bdims)

  def slice_args(i):
    return [slice_fn(i) for slice_fn in per_arg_slicers]

  return slice_args
|
|
|
|
|
2019-12-17 14:44:03 -08:00
|
|
|
# Run all tests with 8 CPU devices, unless the user already chose a count.
def setUpModule():
  """Force an 8-device host platform for this module, remembering old flags."""
  global prev_xla_flags
  prev_xla_flags = os.getenv("XLA_FLAGS")
  current_flags = prev_xla_flags or ""
  # Don't override user-specified device count, or other XLA flags.
  user_chose_count = "xla_force_host_platform_device_count" in current_flags
  if not user_chose_count:
    os.environ["XLA_FLAGS"] = (
        current_flags + " --xla_force_host_platform_device_count=8")
  # Clear any cached backends so new CPU backend will pick up the env var.
  xla_bridge.get_backend.cache_clear()
|
|
|
|
|
|
|
|
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
  """Restore XLA_FLAGS to its pre-setUpModule value and reset the backends.

  If XLA_FLAGS was unset before this module ran, it is removed again. The
  removal uses `pop` with a default so that teardown stays idempotent and
  does not raise KeyError (masking the real test outcome) if the variable
  has already been removed elsewhere.
  """
  if prev_xla_flags is None:
    os.environ.pop("XLA_FLAGS", None)
  else:
    os.environ["XLA_FLAGS"] = prev_xla_flags
  # Drop cached backends so later modules see the restored flags.
  xla_bridge.get_backend.cache_clear()
|
2019-01-28 11:13:34 -08:00
|
|
|
|
2020-08-10 19:09:34 +02:00
|
|
|
# Decorator factories built on jtu.ignore_warning. Apply them called, e.g.
# `@ignore_slow_all_to_all_warning()`, to silence an expected warning whose
# message matches the given pattern (presumably a regex; see jtu).
ignore_jit_of_pmap_warning = partial(jtu.ignore_warning,
                                     message=".*jit-of-pmap.*")

ignore_slow_all_to_all_warning = partial(
    jtu.ignore_warning,
    message="all_to_all.*expect significant slowdowns.*",
)

ignore_xmap_warning = partial(jtu.ignore_warning,
                              message=".*is an experimental.*")
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
|
|
|
|
class PythonPmapTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
@property
def pmap(self):
  """The `pmap` entry point exercised by this test case.

  Tests call `self.pmap(...)` rather than `jax.pmap` so that this choice of
  implementation (`src_api._python_pmap`) lives in a single place.
  """
  pmap_impl = src_api._python_pmap
  return pmap_impl
|
|
|
|
|
2021-10-08 10:41:43 -07:00
|
|
|
def testDeviceBufferToArray(self):
  sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
  buf = sda.device_buffers[-1]

  # copy=False must alias the device buffer's memory; copy=True must not.
  for copy_flag, shares_memory in ((False, True), (True, False)):
    arr = jnp.array(buf, copy=copy_flag)
    self.assertArraysEqual(sda[-1], arr)
    self.assertEqual(buf.device(), arr.device())
    check_ptr = self.assertEqual if shares_memory else self.assertNotEqual
    check_ptr(buf.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer())
|
|
|
|
|
2019-03-21 07:37:43 -07:00
|
|
|
def _getMeshShape(self, device_mesh_shape):
  """Resolve `device_mesh_shape` against the available device count.

  A shape containing -1 is resolved with numpy reshape semantics. Any other
  shape is returned unchanged, provided its size divides the device count.

  Raises:
    SkipTest: if the mesh cannot be formed from the available devices.
  """
  device_count = jax.device_count()
  if all(size != -1 for size in device_mesh_shape):
    mesh_size = prod(device_mesh_shape)
    if device_count % mesh_size:
      msg = "device mesh size {} does not divide available device count {}"
      raise SkipTest(msg.format(mesh_size, device_count))
    return device_mesh_shape

  try:
    return np.arange(device_count).reshape(device_mesh_shape).shape
  except ValueError as err:
    msg = "device mesh shape {} not compatible with device count {}"
    raise SkipTest(msg.format(device_mesh_shape, device_count)) from err
|
|
|
|
|
2019-03-19 16:54:55 -07:00
|
|
|
def testBasic(self):
  f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')

  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  # psum over the mapped axis corresponds to a numpy sum over axis 0.
  expected = x - x.sum(axis=0)

  self.assertAllClose(f(x), expected, check_dtypes=False)
|
|
|
|
|
2021-11-04 09:21:00 -07:00
|
|
|
def testLowerCompile(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  expected = f(x)

  lowered = f.lower(x)
  compiled = lowered.compile()
  self.assertAllClose(compiled(x), expected)

  # in_tree is a pair of: (positional args, as a tuple of their structures,
  # kwargs); in_avals follows the same layout.
  want_tree = jax.tree_flatten(((0,), {}))[1]
  want_avals = ((jax.ShapedArray(x.shape, x.dtype),), {})
  for stage in (lowered, compiled):
    self.assertFalse(stage._no_kwargs)
    self.assertEqual(stage.in_tree, want_tree)
    self.assertEqual(stage.in_avals, want_avals)
|
2022-03-07 02:36:09 -08:00
|
|
|
|
2021-11-04 09:21:00 -07:00
|
|
|
def testLowerCompileInTreeMismatch(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  f_exe = f.lower(x).compile()
  # Wrapping the argument in a list changes the input pytree structure.
  with self.assertRaisesRegex(TypeError,
                              "function compiled for .*, called with .*"):
    f_exe([x])
|
|
|
|
|
|
|
|
def testLowerCompileTrivial(self):
  identity = self.pmap(lambda x: x, axis_name='i')
  x = np.arange(jax.device_count(), dtype=np.float32)
  expected = identity(x)
  executable = identity.lower(x).compile()
  self.assertAllClose(executable(x), expected)
|
|
|
|
|
|
|
|
def testLowerCompileTrivialInTreeMismatch(self):
  identity = self.pmap(lambda x: x, axis_name='i')
  x = np.arange(jax.device_count(), dtype=np.float32)
  executable = identity.lower(x).compile()
  # A list-wrapped argument does not match the compiled input pytree.
  with self.assertRaisesRegex(TypeError,
                              "function compiled for .*, called with .*"):
    executable([x])
|
|
|
|
|
2021-11-04 09:32:19 -07:00
|
|
|
def testLowerCompileArgTypeMismatch(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  base = np.arange(prod(shape), dtype=int).reshape(shape)
  f_exe = f.lower(base.astype(jnp.float32)).compile()
  # Calling the float32 executable with int32 data must be rejected.
  with self.assertRaisesRegex(
      TypeError,
      "Computation compiled for input types:\n.*float32.*\n"
      "called with:\n.*int32.*"):
    f_exe(base.astype(jnp.int32))
|
|
|
|
|
2021-11-04 09:21:00 -07:00
|
|
|
def testLowerCompileMultiArg(self):
  f = self.pmap(lambda x, y: x - lax.pmean(y, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  y = x  # both arguments intentionally share the same array
  expected = f(x, y)
  f_exe = f.lower(x, y).compile()
  self.assertAllClose(f_exe(x, y), expected)
|
|
|
|
|
|
|
|
def testLowerCompileTrivialMultiArg(self):
  pair = self.pmap(lambda x, y: (x, y), axis_name='i')
  x = np.arange(jax.device_count(), dtype=np.float32)
  y = x  # both arguments intentionally share the same array
  expected = pair(x, y)
  executable = pair.lower(x, y).compile()
  self.assertAllClose(executable(x, y), expected)
|
|
|
|
|
2022-01-13 15:42:17 -08:00
|
|
|
def testLowerCompilerIR(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  lowered = f.lower(x)
  # Default dialect first, then each explicitly requested dialect.
  for ir_kwargs in ({}, {'dialect': 'hlo'}, {'dialect': 'mhlo'}):
    self.assertIsNotNone(lowered.compiler_ir(**ir_kwargs))
|
|
|
|
|
|
|
|
def testLowerCompileCompilerIR(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  self.assertIsNotNone(compiled.compiler_ir())
|
|
|
|
|
|
|
|
def testLowerCompileExecutable(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  compiled = f.lower(x).compile()
  self.assertIsNotNone(compiled.runtime_executable())
|
|
|
|
|
2020-01-29 18:10:48 +00:00
|
|
|
def testMean(self):
  f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')

  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  mean_row = np.mean(x, 0)
  expected = x - np.broadcast_to(mean_row, x.shape)

  self.assertAllClose(f(x), expected, check_dtypes=False)
|
|
|
|
|
2020-03-19 15:35:00 +00:00
|
|
|
def testGather(self):
  f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')

  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  # Every device ends up holding the whole gathered input.
  expected = np.array([x] * jax.device_count())

  self.assertAllClose(f(x), expected, check_dtypes=False)
|
|
|
|
|
2021-12-03 16:11:27 -08:00
|
|
|
def testGatherBool(self):
  f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')

  shape = (jax.device_count(), 4)
  counts = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  x = (counts % 2).astype(np.bool_)  # alternating False/True pattern
  expected = np.array([x] * jax.device_count())

  self.assertAllClose(f(x), expected, check_dtypes=False)
|
|
|
|
|
2021-07-15 04:22:08 -07:00
|
|
|
def testGatherTiled(self):
  f = self.pmap(lambda x: lax.all_gather(x, 'i', tiled=True), axis_name='i')

  device_count = jax.device_count()
  shape = (device_count, 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  # With tiled=True each device's output row is the flattened full input.
  stacked = np.array([x] * device_count)
  expected = stacked.reshape(device_count, -1)

  self.assertAllClose(f(x), expected, check_dtypes=False)
|
|
|
|
|
2021-07-27 11:12:51 -07:00
|
|
|
def testReduceScatter(self):
  f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')

  device_count = jax.device_count()
  shape = (device_count, device_count)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  column_sums = np.sum(x, axis=0)
  ans = f(x)

  # Device `idx` receives element `idx` of the summed vector.
  for idx, per_device in enumerate(ans):
    self.assertAllClose(per_device, column_sums[idx])
|
|
|
|
|
|
|
|
def testReduceScatterTiled(self):
  f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')

  device_count = jax.device_count()
  shape = (device_count, 4 * device_count)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  column_sums = np.sum(x, axis=0)
  ans = f(x)

  # Each device holds one contiguous chunk of the summed vector.
  chunk = len(column_sums) // device_count
  for idx, per_device in enumerate(ans):
    start = idx * chunk
    self.assertAllClose(per_device, column_sums[start:start + chunk])
|
|
|
|
|
|
|
|
def testReduceScatterReplicaGroupsTiled(self):
  # psum_scatter over two interleaved replica groups: even- and odd-indexed
  # devices reduce independently and each group scatters its own result.
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest("test requires an even number of devices")
  axis_index_groups = [list(range(0, replicas, 2)),
                       list(range(1, replicas, 2))]
  f = lambda x: lax.psum_scatter(
      x, 'i', axis_index_groups=axis_index_groups, tiled=True)
  f = self.pmap(f, axis_name='i')

  shape = (replicas, 4 * replicas)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  ans = f(x)

  group_1_result = np.sum(x[0::2, :], axis=0)
  group_2_result = np.sum(x[1::2, :], axis=0)
  # The result is scattered over (replicas // 2) devices within each group.
  scatter_len = len(group_1_result) * 2 // replicas

  for i, actual in enumerate(ans):
    expected = group_1_result if i % 2 == 0 else group_2_result
    # Device i is at position i // 2 within its group.
    start = (i // 2) * scatter_len
    self.assertAllClose(actual, expected[start:start + scatter_len])
|
|
|
|
|
2020-08-28 15:21:50 +00:00
|
|
|
@ignore_slow_all_to_all_warning()
def testTrees(self):
  def ptranspose(x, axis_name):
    return lax.all_to_all(x, axis_name, 0, 0)

  def protate(x, axis_name):
    n = lax.psum(1, axis_name)
    return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)])

  def tree_f(f):
    return partial(tree_util.tree_map, f)

  def jax_f(p):
    return self.pmap(lambda x: p(x, 'i'), 'i')

  def np_f(p):
    return tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))

  np_transpose = tree_f(np.transpose)
  np_rotate = tree_f(lambda x: np.concatenate([x[-1:], x[:-1]]))

  n = jax.device_count()
  # Leaf `key` holds the consecutive integers [lo*n*n, (lo+1)*n*n) as n x n.
  x = {key: np.arange(lo * n * n, (lo + 1) * n * n).reshape([n, n])
       for key, lo in (('a', 1), ('b', 2), ('c', 4))}

  assert_allclose = partial(tree_util.tree_map,
                            partial(self.assertAllClose, check_dtypes=False))
  reductions = ((lax.pmax, np.max), (lax.pmin, np.min),
                (lax.psum, np.sum), (lax.pmean, np.mean))
  for jax_op, np_op in reductions:
    assert_allclose(jax_f(jax_op)(x), np_f(np_op)(x))
  assert_allclose(jax_f(ptranspose)(x), np_transpose(x))
  assert_allclose(jax_f(protate)(x), np_rotate(x))
|
2020-01-29 03:04:59 +00:00
|
|
|
|
2020-03-04 11:35:52 -05:00
|
|
|
def testCollectivesWithTreesOfDifferentDtypes(self):
  n = len(jax.devices())
  # (leaf name, range start multiplier, dtype): float32 and int32 leaves mixed.
  layout = (('a', 1, np.float32), ('b', 2, np.int32),
            ('c', 4, np.float32), ('d', 6, np.int32))
  x = {key: np.arange(lo * n * n, (lo + 1) * n * n, dtype=dt).reshape([n, n])
       for key, lo, dt in layout}

  tree_f = lambda f: partial(tree_util.tree_map, f)
  jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
  np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
  assert_allclose = partial(tree_util.tree_map,
                            partial(self.assertAllClose, check_dtypes=False))

  for jax_op, np_op in ((lax.pmax, np.max), (lax.pmin, np.min),
                        (lax.psum, np.sum), (lax.pmean, np.mean)):
    assert_allclose(jax_f(jax_op)(x), np_f(np_op)(x))
|
2020-03-04 11:35:52 -05:00
|
|
|
|
2019-10-15 22:55:35 +00:00
|
|
|
def testComplexPsum(self):
  f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')

  shape = (jax.device_count(), 4 * 2)
  floats = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  # Reinterpret adjacent float32 pairs as complex64 values (4 per row).
  x = floats.view(np.complex64)
  expected = x - np.sum(x, 0)

  self.assertAllClose(f(x), expected, check_dtypes=False)
|
|
|
|
|
2020-09-23 10:45:23 +00:00
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
     "split_axis": split_axis, "concat_axis": concat_axis}
    for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAll(self, split_axis, concat_axis):
  pmap_in_axis = 0
  shape = (jax.device_count(),) * 3
  x = np.arange(np.prod(shape)).reshape(shape)

  @partial(self.pmap, axis_name='i')
  def f(x):
    return lax.all_to_all(x, 'i', split_axis, concat_axis)

  y = f(x)
  # split_axis is relative to the per-device array (mapped axis removed), so
  # shift it when indexing the full array.
  if pmap_in_axis <= split_axis:
    full_split_axis = split_axis + 1
  else:
    full_split_axis = split_axis
  ref = jnp.moveaxis(x, (pmap_in_axis, full_split_axis),
                     (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": f"_split={split_axis}_concat={concat_axis}",
     "split_axis": split_axis, "concat_axis": concat_axis}
    for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAllSplitAxis(self, split_axis, concat_axis):
  if jax.device_count() < 4:
    raise SkipTest("test requires at least four devices")
  pmap_in_axis = 0
  shape = (4, 4, 4)
  x = np.arange(np.prod(shape)).reshape(shape)

  @partial(self.pmap, axis_name='i')
  @partial(self.pmap, axis_name='j')
  def f(x):
    return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)

  # Split the leading axis of 4 into a 2 x 2 grid for the nested pmaps,
  # then fold the result back into the original shape.
  grid_x = x.reshape((2, 2, *shape[1:]))
  y = f(grid_x).reshape(shape)

  if pmap_in_axis <= split_axis:
    full_split_axis = split_axis + 1
  else:
    full_split_axis = split_axis
  ref = jnp.moveaxis(x, (pmap_in_axis, full_split_axis),
                     (concat_axis + 1, 0))
  self.assertAllClose(y, ref)
|
2019-10-15 22:55:35 +00:00
|
|
|
|
2019-03-19 16:54:55 -07:00
|
|
|
def testNestedBasic(self):
  """Nested psums over 'i' then 'j' reduce across both mapped axes."""
  def double_psum(v):
    return lax.psum(lax.psum(v, 'i'), 'j')
  nested = self.pmap(self.pmap(double_psum, 'i'), 'j')

  def reduce_keep(arr, axis):
    # Sum along `axis`, then tile the sum back to the original extent.
    total = np.sum(arr, axis, keepdims=True)
    return np.repeat(total, arr.shape[axis], axis)

  dims = (jax.device_count(), 1, 4)
  inputs = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  got = nested(inputs)
  want = reduce_keep(reduce_keep(inputs, 0), 1)
  self.assertAllClose(got, want, check_dtypes=False)
|
|
|
|
|
2019-10-10 13:13:21 -04:00
|
|
|
def testMismatchedAxisSizes(self):
  """pmap rejects arguments whose mapped axes have different sizes."""
  n = jax.device_count()
  add = self.pmap(lambda a, b: a + b)
  with self.assertRaisesRegex(
      ValueError, "pmap got inconsistent sizes for array axes to be mapped"):
    add(self.rng().randn(n), self.rng().randn(n - 1))
|
2019-10-10 13:13:21 -04:00
|
|
|
|
2019-03-19 16:54:55 -07:00
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_mesh={}".format(device_mesh_shape).replace(" ", ""),
     "device_mesh_shape": device_mesh_shape}
    for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testNestedShardingAndStacking(self, device_mesh_shape):
  """An identity function under two nested pmaps round-trips its input."""
  mesh_shape = self._getMeshShape(device_mesh_shape)
  identity = self.pmap(self.pmap(lambda v: v, 'i'), 'j')

  full_shape = mesh_shape + (4,)
  data = np.arange(prod(full_shape), dtype=np.float32).reshape(full_shape)

  out = identity(data)
  self.assertEqual(out.shape, data.shape)
  self.assertAllClose(out, data, check_dtypes=False)
|
|
|
|
|
2020-05-01 14:37:13 -07:00
|
|
|
def testPartiallyMapped(self):
  """Unmapped (in_axes=None) inputs are broadcast to every device."""
  select_first = self.pmap(lambda a, b: a, in_axes=(None, 0))
  minus_total = self.pmap(lambda a, b: a - lax.psum(b, 'i'), axis_name='i',
                          in_axes=(None, 0))

  n = jax.device_count()
  dims = (n, 4)
  scalar = np.array(3., dtype=np.float32)
  mapped = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  def check_sharded(result):
    # out_axes is implicitly 0, so even when every device buffer holds the
    # same values the sharding spec must not contain a Replicated entry.
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertEmpty([entry for entry in result.sharding_spec.mesh_mapping
                      if isinstance(entry, pxla.Replicated)])

  first = select_first(scalar, mapped)
  self.assertAllClose(first, np.broadcast_to(scalar, (n,)))
  check_sharded(first)

  diff = minus_total(scalar, mapped)
  want = np.broadcast_to(scalar - np.sum(mapped, 0, keepdims=True), dims)
  self.assertAllClose(diff, want)
  check_sharded(diff)
|
2020-05-01 14:37:13 -07:00
|
|
|
|
2021-01-28 23:14:26 -08:00
|
|
|
def testReplicate(self):
  """pxla.replicate yields an array whose mesh mapping is all Replicated."""
  data = np.array([3., 4.], dtype=np.float32)
  count = jax.device_count()
  out = pxla.replicate(data, count, count, in_axis=None)
  self.assertAllClose(data, out)
  non_replicated = [entry for entry in out.sharding_spec.mesh_mapping
                    if not isinstance(entry, pxla.Replicated)]
  self.assertEmpty(non_replicated)
|
|
|
|
|
2020-05-01 14:37:13 -07:00
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_mesh={}".format(device_mesh_shape).replace(" ", ""),
     "device_mesh_shape": device_mesh_shape}
    for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testPartiallyMappedNested(self, device_mesh_shape):
  """Nested pmaps that both broadcast an unmapped scalar argument."""
  mesh_shape = self._getMeshShape(device_mesh_shape)
  inner = self.pmap(lambda a, b: a - lax.psum(b, 'i'), axis_name='i',
                    in_axes=(None, 0))
  outer = self.pmap(inner, axis_name='j', in_axes=(None, 0))

  scalar = 3.
  grid = np.arange(prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
  # The outer pmap ('j') consumes axis 0, so the inner axis 'i' is axis 1
  # of the full array; psum over 'i' is a sum along that axis.
  want = np.broadcast_to(scalar - np.sum(grid, 1, keepdims=True), mesh_shape)
  self.assertAllClose(outer(scalar, grid), want, check_dtypes=False)
|
|
|
|
|
2019-03-19 16:54:55 -07:00
|
|
|
def testJvpAndPartialEval(self):
  """linearize through a pmapped sin gives cos and is traceable to a jaxpr."""
  @partial(self.pmap, axis_name='i')
  def mapped_sin(v):
    return jnp.sin(v)

  def tangent_of_ones(v):
    # Linearize (JVP + partial eval), then push a ones tangent through.
    _, linear_fn = linearize(mapped_sin, v)
    return linear_fn(jnp.ones_like(v))

  dims = (jax.device_count(), 4)
  points = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  self.assertAllClose(tangent_of_ones(points), np.cos(points),
                      check_dtypes=False)
  # Tracing the whole linearize-and-apply pipeline must not raise.
  make_jaxpr(tangent_of_ones)(points)
|
|
|
|
|
|
|
|
def testGradBasic(self):
  """grad through a pmapped sin matches grad of the unmapped computation."""
  @partial(self.pmap, axis_name='i')
  def mapped_sin(v):
    return jnp.sin(v)

  dims = (jax.device_count(), 4)
  points = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  baseline = grad(lambda v: jnp.sum(jnp.sin(v)))(points)
  via_pmap = grad(lambda v: jnp.sum(mapped_sin(v)))(points)
  self.assertAllClose(baseline, via_pmap, check_dtypes=False)
|
|
|
|
|
2019-11-06 08:36:53 -08:00
|
|
|
def testGradOfPsum(self):
  """psum under pmap passes numerical first- and second-order grad checks."""
  @partial(self.pmap, axis_name='i')
  def total(v):
    return lax.psum(v, axis_name='i')

  dims = (jax.device_count(), 4)
  points = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  jtu.check_grads(total, (points,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
|
|
|
|
|
2019-03-19 16:54:55 -07:00
|
|
|
def testGradOfJvp(self):
  """Reverse mode through a linearized pmap matches grad of a plain jvp."""
  @partial(self.pmap, axis_name='i')
  def mapped_sin(v):
    return jnp.sin(v)

  def tangent_of_ones(v):
    _, linear_fn = linearize(mapped_sin, v)
    return linear_fn(jnp.ones_like(v))

  def reference(v):
    _, tangent = jvp(jnp.sin, (v,), (jnp.ones_like(v),))
    return jnp.sum(tangent)

  dims = (jax.device_count(), 4)
  points = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  got = grad(lambda v: jnp.sum(tangent_of_ones(v)))(points)
  self.assertAllClose(got, grad(reference)(points))
|
2019-03-19 16:54:55 -07:00
|
|
|
|
|
|
|
def testTwoArgsGrad(self):
  """Gradients of a two-argument pmapped psum match an unmapped reference."""
  def f(x, y):
    return lax.psum(5. * jnp.cos(x) * jnp.sin(y), 'i')
  f = self.pmap(f, 'i')

  def g(x, y):
    # Unmapped equivalent of `f`: psum over the mapped axis 0 is a sum over
    # rows, and every device receives that same summed row, so broadcast it
    # back to the mapped shape.
    col_sum = jnp.sum(5. * jnp.cos(x) * jnp.sin(y), axis=0)
    return jnp.broadcast_to(col_sum, x.shape)

  shape = (jax.device_count(), 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  y = 4 + x
  # Differentiate the pmapped `f` (not the reference `g`) so pmap is actually
  # exercised, and check the gradient with respect to both arguments.
  ans = grad(lambda x, y: jnp.sum(f(x, y)), argnums=(0, 1))(x, y)
  expected = grad(lambda x, y: jnp.sum(g(x, y)), argnums=(0, 1))(x, y)
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_mesh={}".format(device_mesh_shape).replace(" ", ""),
     "device_mesh_shape": device_mesh_shape}
    for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testNestedWithClosure(self, device_mesh_shape):
  """Nested pmaps closing over outer values agree with nested vmaps."""
  mesh_shape = self._getMeshShape(device_mesh_shape)

  def build(outer_map, inner_map):
    # Same computation for the pmap version and the vmap baseline; only the
    # mapping transforms differ.
    @outer_map
    def fun(x):
      y = jnp.sum(jnp.sin(x))

      @inner_map
      def g(z):
        return 3. * jnp.exp(jnp.sin(x).sum() * jnp.cos(y) * jnp.tan(z))

      return grad(lambda w: jnp.sum(g(w)))(x)
    return fun

  test_fun = build(partial(self.pmap, axis_name='i'),
                   partial(self.pmap, axis_name='j'))
  baseline_fun = build(vmap, vmap)

  dims = mesh_shape + (4,)
  points = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  got = grad(lambda v: jnp.sum(test_fun(v)))(points)
  want = grad(lambda v: jnp.sum(baseline_fun(v)))(points)
  self.assertAllClose(got, want, atol=1e-3)
|
2019-02-01 16:59:28 -08:00
|
|
|
|
2019-05-02 22:13:49 -07:00
|
|
|
def testShardedDeviceArrays(self):
  """ShardedDeviceArrays flow through pmap, plain ops, device_put and jit."""
  double = self.pmap(lambda v: 2 * v, axis_name='i')

  dims = (jax.device_count(), 4)
  data = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  def check_sda(arr):
    self.assertIsInstance(arr, pxla.ShardedDeviceArray)
    self.assertIsInstance(arr, device_array.DeviceArray)
    self.assertNotIsInstance(arr, np.ndarray)

  # pmap outputs are ShardedDeviceArrays and can be fed back into pmap.
  once = double(data)
  self.assertIsInstance(once, jnp.ndarray)
  check_sda(once)
  self.assertAllClose(once, 2 * data, check_dtypes=False)

  twice = double(once)
  check_sda(twice)
  self.assertAllClose(twice, 2 * 2 * data, check_dtypes=False)

  # A regular DeviceArray input is accepted too.
  from_put = double(device_put(data))
  self.assertIsInstance(from_put, pxla.ShardedDeviceArray)
  self.assertAllClose(from_put, 2 * data, check_dtypes=False)

  # A ShardedDeviceArray can feed an ordinary (non-pmap) computation.
  self.assertAllClose(from_put + from_put, 2 * 2 * data, check_dtypes=False)

  # Reversing the device buffers forces device movement on dispatch.
  shuffled = pxla.make_sharded_device_array(
      from_put.aval, from_put.sharding_spec, from_put.device_buffers[::-1])
  result = double(shuffled)
  self.assertAllClose(result, 2 * 2 * data[::-1], check_dtypes=False)

  repr(result)  # must not raise

  # A ShardedDeviceArray can be captured lexically as a jit constant.
  add_captured = jit(lambda v: v + shuffled)
  self.assertAllClose(add_captured(7), shuffled + 7)
|
|
|
|
|
|
|
|
|
2020-04-15 18:43:46 -07:00
|
|
|
# Tests edge cases in lax._reshape_sharded_device_array
|
|
|
|
@parameterized.named_parameters(
    {"testcase_name": "_in={}_out={}".format(in_shape, out_shape)
     .replace(" ", ""),
     "in_shape": in_shape, "out_shape": out_shape}
    for in_shape, out_shape in [
        [(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)]
    ])
def testShardedDeviceArrayReshape(self, in_shape, out_shape):
  """Reshaping a ShardedDeviceArray agrees with reshaping on the host."""
  available = jax.device_count()
  needed = max(in_shape[:1] + out_shape[:1])
  if available < needed:
    raise SkipTest("not enough devices")

  host = np.arange(prod(in_shape)).reshape(in_shape)
  sharded = self.pmap(lambda v: v)(host)
  self.assertAllClose(sharded.reshape(out_shape), host.reshape(out_shape),
                      check_dtypes=False)
|
|
|
|
|
2019-04-01 17:56:23 -07:00
|
|
|
def testPsumMultiple(self):
  """psum over a tuple of axis names reduces across both nested axes."""
  summed = self.pmap(self.pmap(lambda v: lax.psum(v, ('i', 'j')), 'i'), 'j')

  def reduce_keep(arr, axis):
    total = np.sum(arr, axis, keepdims=True)
    return np.repeat(total, arr.shape[axis], axis)

  count = jax.device_count()
  pairs, leftover = divmod(count, 2)
  # Use a genuine 2-D mesh when the devices split evenly into several pairs.
  if pairs > 1 and not leftover:
    dims = (pairs, 2, 4)
  else:
    dims = (count, 1, 4)
  data = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  got = summed(data)
  self.assertAllClose(got, reduce_keep(reduce_keep(data, 0), 1),
                      check_dtypes=False)
|
|
|
|
|
2020-08-14 22:54:36 -07:00
|
|
|
def testPsumConstantReplicaGroups(self):
  """psum of a constant within replica groups scales it by the group size."""
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest("test requires an even number of devices")
  # Two equal groups: the first half and the second half of the devices.
  group_size = replicas // 2
  axis_index_groups = np.arange(replicas).reshape(2, group_size).tolist()
  f = lambda x: x - lax.psum(2., 'i', axis_index_groups=axis_index_groups)
  f = self.pmap(f, 'i')

  shape = (replicas, 4)
  x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  # psum of the constant 2. within a group adds it once per group member.
  # Parenthesize explicitly: `2. * replicas // 2` parses as
  # `(2. * replicas) // 2`, which only coincides with this for even counts.
  expected_psum = 2. * group_size
  expected = x - expected_psum

  ans = f(x)
  self.assertAllClose(ans, expected, check_dtypes=False)
|
2021-09-01 13:08:10 -07:00
|
|
|
|
|
|
|
@jtu.skip_on_devices("tpu")
def testPsumUnevenReplicaGroups(self):
  """psum with replica groups of unequal size: [0, 1] and all the rest."""
  replicas = jax.device_count()
  if replicas <= 2:
    raise SkipTest("Test expected devices greater than 2.")
  groups = [[0, 1], np.arange(2, replicas)]
  subtract_group_sum = self.pmap(
      lambda v: v - lax.psum(v, 'i', axis_index_groups=groups), 'i')

  dims = (replicas, 4)
  data = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  def group_total(rows):
    # Sum one group's rows and repeat that sum once per group member.
    return np.broadcast_to(rows.sum(0, keepdims=True),
                           (len(rows), data.shape[1]))

  totals = np.concatenate([group_total(data[:2]), group_total(data[2:])], 0)
  self.assertAllClose(subtract_group_sum(data), data - totals,
                      check_dtypes=False)
|
2020-08-14 22:54:36 -07:00
|
|
|
|
2020-05-08 14:00:34 -07:00
|
|
|
def testPsumReplicaGroups(self):
  """psum within two equal replica groups (first half / second half)."""
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest
  half = replicas // 2
  groups = np.arange(replicas).reshape(2, half).tolist()
  subtract_group_sum = self.pmap(
      lambda v: v - lax.psum(v, 'i', axis_index_groups=groups), 'i')

  dims = (replicas, 4)
  data = np.arange(prod(dims), dtype=np.float32).reshape(dims)

  def group_total(rows):
    return np.broadcast_to(rows.sum(0, keepdims=True), (half, data.shape[1]))

  totals = np.concatenate(
      [group_total(data[:half]), group_total(data[half:])], 0)
  self.assertAllClose(subtract_group_sum(data), data - totals,
                      check_dtypes=False)
|
|
|
|
|
2020-09-09 13:02:45 +01:00
|
|
|
def testGatherReplicaGroups(self):
  """all_gather within two groups: even and odd device indices."""
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest("Test expected an even number of devices greater than 1.")

  # Reshape to (pairs, 2) and transpose: row 0 holds the even indices and
  # row 1 the odd ones.
  index_grid = np.arange(replicas, dtype=np.int32).reshape((replicas // 2, 2))
  groups = index_grid.T.tolist()
  gather = self.pmap(
      lambda v: lax.all_gather(v, 'i', axis_index_groups=groups), 'i')

  dims = (replicas, 4)
  data = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  got = gather(data)

  want = np.empty((replicas, replicas // 2, data.shape[1]))
  want[0::2] = data[0::2]
  want[1::2] = data[1::2]
  self.assertAllClose(got, want, check_dtypes=False)
|
|
|
|
|
|
|
|
def testGatherReplicaGroupsInterleaved(self):
  """all_gather with groups built from an evens-then-odds index ordering."""
  replicas = jax.device_count()
  if replicas % 2 != 0:
    raise SkipTest("Test expected an even number of devices greater than 1.")

  order = np.arange(replicas)
  # Evens first, then odds; splitting that in half yields the two groups.
  evens_then_odds = np.concatenate([order[::2], order[1::2]])
  groups = evens_then_odds.reshape(2, replicas // 2).tolist()
  gather = self.pmap(
      lambda v: lax.all_gather(v, 'i', axis_index_groups=groups), 'i')

  dims = (replicas, 4)
  data = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  got = gather(data)

  want = np.zeros((replicas, replicas // 2, data.shape[1]))
  want[::2] = data[::2]
  want[1::2] = data[1::2]
  self.assertAllClose(got, want, check_dtypes=False)
|
|
|
|
|
2021-03-05 17:59:16 +00:00
|
|
|
@ignore_slow_all_to_all_warning()
def testGradOfGather(self):
  """all_gather under pmap passes numerical first/second-order grad checks."""
  @partial(self.pmap, axis_name='i')
  def gather(v):
    return lax.all_gather(v, axis_name='i')

  dims = (jax.device_count(), 4)
  points = np.arange(prod(dims), dtype=np.float32).reshape(dims)
  jtu.check_grads(gather, (points,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
|
|
|
|
|
2020-05-08 14:00:34 -07:00
|
|
|
def testNestedPmapReplicaGroups(self):
  """Replica groups on axis 'i' under several nestings of two pmaps."""
  replicas = jax.device_count()
  if replicas % 4 != 0:
    raise SkipTest
  quarter = replicas // 4
  groups = np.arange(replicas // 2).reshape(2, quarter).tolist()
  sub = lambda v: v - lax.psum(v, 'i', axis_index_groups=groups)

  i_inside_j = self.pmap(self.pmap(sub, 'i'), 'j')
  # "imperfectly nested" case: extra work outside the inner pmap.
  i_inside_j_plus_one = self.pmap(lambda v: self.pmap(sub, 'i')(v) + 1., 'j')
  j_inside_i = self.pmap(self.pmap(sub, 'j'), 'i')

  # With 'i' innermost, 'i' indexes axis 1 of a (2, replicas // 2, 4) array.
  dims_a = (2, replicas // 2, 4)
  data_a = np.arange(prod(dims_a), dtype=np.float32).reshape(dims_a)

  def total_axis1(a):
    return np.broadcast_to(a.sum(1, keepdims=True),
                           (dims_a[0], dims_a[1] // 2, dims_a[2]))

  sums_a = np.concatenate(
      [total_axis1(data_a[:, :quarter]), total_axis1(data_a[:, quarter:])], 1)
  self.assertAllClose(i_inside_j(data_a), data_a - sums_a)
  self.assertAllClose(i_inside_j_plus_one(data_a), data_a - sums_a + 1.)

  # With 'i' outermost, 'i' indexes axis 0 of a (replicas // 2, 2, 4) array.
  dims_b = (replicas // 2, 2, 4)
  data_b = np.arange(prod(dims_b), dtype=np.float32).reshape(dims_b)

  def total_axis0(a):
    return np.broadcast_to(a.sum(0, keepdims=True),
                           (dims_b[0] // 2, dims_b[1], dims_b[2]))

  sums_b = np.concatenate(
      [total_axis0(data_b[:quarter]), total_axis0(data_b[quarter:])], 0)
  self.assertAllClose(j_inside_i(data_b), data_b - sums_b)
|
2020-05-08 14:00:34 -07:00
|
|
|
|
2019-11-13 21:10:16 -08:00
|
|
|
def testAxisGroups(self):
  """xla.axis_groups for an 8-replica environment with axes i=4, j=2."""
  env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
  cases = [
      ('i', ((0, 2, 4, 6), (1, 3, 5, 7))),
      ('j', ((0, 1), (2, 3), (4, 5), (6, 7))),
      (('i', 'j'), ((0, 1, 2, 3, 4, 5, 6, 7,),)),
  ]
  for names, want in cases:
    self.assertEqual(xla.axis_groups(env, names), want)

  # Reversed axis order still yields one group covering every replica;
  # the order of members within that group is not checked.
  reversed_groups = xla.axis_groups(env, ('j', 'i'))
  self.assertEqual(len(reversed_groups), 1)
  self.assertEqual((tuple(sorted(reversed_groups[0])),),
                   ((0, 1, 2, 3, 4, 5, 6, 7,),))
|
|
|
|
|
2019-05-09 15:46:34 -07:00
|
|
|
def testCollectivePermute(self):
  """A cyclic ppermute by one device matches np.roll along axis 0."""
  n = jax.device_count()
  rotation = [(src, (src + 1) % n) for src in range(n)]
  rotate = self.pmap(
      lambda v: lax.ppermute(v, perm=rotation, axis_name='i'), 'i')

  data = jnp.arange(4 * n).reshape((n, 4))
  self.assertAllClose(rotate(data), np.roll(data, shift=1, axis=0),
                      check_dtypes=False)
|
|
|
|
|
2021-01-06 10:15:28 +00:00
|
|
|
@jtu.skip_on_devices("cpu")
def testCollectivePermuteGrad(self):
  """Gradient of a non-cyclic shift-right ppermute."""
  n = jax.device_count()
  # Device i sends to i + 1; the last device's value is dropped.
  shift_right = [(src, src + 1) for src in range(n - 1)]
  weights = np.pi + np.arange(n, dtype=np.float32)

  def objective(v):
    shifted = self.pmap(
        lambda u: lax.ppermute(u, perm=shift_right, axis_name='i'), 'i')(v)
    return jnp.sum(weights * shifted)

  data = np.arange(n, dtype=np.float32)
  # Input i lands in output i + 1, so its gradient is weights[i + 1]; the
  # last input reaches no output and gets a zero gradient.
  want = np.concatenate([np.pi + np.arange(1, n), [0]])
  self.assertAllClose(grad(objective)(data), want, check_dtypes=False)
|
|
|
|
|
|
|
|
def testCollectivePermuteCyclicGrad(self):
  """Gradient of a cyclic shift-right ppermute, analytic and numerical."""
  n = jax.device_count()
  shift_right = [(src, (src + 1) % n) for src in range(n)]
  weights = np.pi + np.arange(n, dtype=np.float32)

  def objective(v):
    shifted = self.pmap(
        lambda u: lax.ppermute(u, perm=shift_right, axis_name='i'), 'i')(v)
    return jnp.sum(weights * shifted)

  data = np.arange(n, dtype=np.float32)

  want = np.roll(np.pi + np.arange(n), -1)
  self.assertAllClose(grad(objective)(data), want, check_dtypes=False)

  jtu.check_grads(objective, (data,), 2, ["fwd", "rev"], 1e-2, 1e-2)
|
|
|
|
|
2020-01-10 16:49:08 -08:00
|
|
|
def testCollectivePermuteCyclicWithPShuffle(self):
  """pshuffle with a cyclic source list matches np.roll by one."""
  n = jax.device_count()
  values = np.arange(n)
  # Entry k is the device whose value output position k receives, so
  # taking k - 1 rolls the values right by one.
  source_of = [(dst - 1) % n for dst in range(n)]
  shuffle = self.pmap(
      lambda v: lax.pshuffle(v, perm=source_of, axis_name='i'), "i")

  got = np.asarray(shuffle(values))
  self.assertAllClose(got, np.roll(values, 1), check_dtypes=False)
|
|
|
|
|
|
|
|
def testPShuffleWithBadPerm(self):
  """pshuffle rejects a perm that is not a permutation of device indices."""
  n = jax.device_count()
  # Replace index 0 with 1, so 0 is missing and 1 appears twice when n > 1.
  bad_perm = [1] + list(range(1, n))
  with self.assertRaisesRegex(
      ValueError, "`perm` does not represent a permutation: \\[1.*\\]"):
    self.pmap(lambda v: lax.pshuffle(v, perm=bad_perm, axis_name='i'),
              "i")(np.arange(n))
|
2020-01-10 16:49:08 -08:00
|
|
|
|
2019-11-16 14:40:25 -08:00
|
|
|
def testPpermuteWithZipObject(self):
  """ppermute accepts a one-shot zip iterator as its perm argument."""
  # https://github.com/google/jax/issues/1703
  n = jax.device_count()
  perm = [n - 1] + list(range(n - 1))
  # The zip is rebuilt on every call, exactly as a user would pass it.
  rotate = self.pmap(
      lambda v: lax.ppermute(v, "i", zip(perm, range(n))), "i")

  got = rotate(jnp.arange(n, dtype=jnp.float32))
  self.assertAllClose(got, jnp.asarray(perm, dtype=jnp.float32))
|
2019-11-15 14:33:39 -08:00
|
|
|
|
2019-05-09 15:46:34 -07:00
|
|
|
def testRule30(self):
  """Run a rule 30 cellular automaton using ppermute halo exchange.

  This tests collective_permute implementing a simple halo exchange to run a
  rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30. Halo exchange
  should be useful in spatially-sharded convolutions and in other
  simulations. The sharded result is compared step by step against an
  independent single-device numpy simulation with the same periodic
  boundary.
  """
  device_count = jax.device_count()
  board_size = 40
  if board_size % device_count:
    raise SkipTest(f"board of {board_size} cells cannot be split evenly "
                   f"across {device_count} devices")

  def send_right(x, axis_name):
    right_perm = [(i, (i + 1) % device_count) for i in range(device_count)]
    return lax.ppermute(x, perm=right_perm, axis_name=axis_name)

  def send_left(x, axis_name):
    left_perm = [((i + 1) % device_count, i) for i in range(device_count)]
    return lax.ppermute(x, perm=left_perm, axis_name=axis_name)

  def update_board(board):
    left = board[:-2]
    right = board[2:]
    center = board[1:-1]
    return lax.bitwise_xor(left, lax.bitwise_or(center, right))

  @partial(self.pmap, axis_name='i')
  def step(board_slice):
    left, right = board_slice[:1], board_slice[-1:]
    # Each shard receives (cyclically) its right neighbour's leftmost cell
    # and its left neighbour's rightmost cell, forming a one-cell halo.
    right, left = send_left(left, 'i'), send_right(right, 'i')
    enlarged_board_slice = jnp.concatenate([left, board_slice, right])
    return update_board(enlarged_board_slice)

  def reference_step(board):
    # Single-device rule 30 with the same wrap-around boundary as the ring
    # of ppermutes above.
    left = np.roll(board, 1)
    right = np.roll(board, -1)
    return np.logical_xor(left, np.logical_or(board, right))

  def render(board):
    return ''.join('*' if x else ' ' for x in np.asarray(board).ravel())

  board = np.zeros(board_size, dtype=bool)
  board[board.shape[0] // 2] = True
  reshaped_board = board.reshape((device_count, -1))

  boards = [render(reshaped_board)]
  expected_boards = [render(board)]
  for _ in range(20):
    reshaped_board = step(reshaped_board)
    board = reference_step(board)
    boards.append(render(reshaped_board))
    expected_boards.append(render(board))

  # Compare the rendered histories so a failure shows a readable diff.
  self.assertEqual('\n'.join(boards), '\n'.join(expected_boards))
|
|
|
|
|
|
|
|
def testReduceMax(self):
  # Each shard subtracts the cross-device maximum over 'i', which equals the
  # column-wise max of the stacked input.
  n_rows = jax.device_count()
  data = np.arange(n_rows * 4, dtype=np.float32).reshape((n_rows, 4))

  @partial(self.pmap, axis_name='i')
  def minus_max(row):
    return row - lax.pmax(row, 'i')

  self.assertAllClose(minus_max(data), data - data.max(axis=0),
                      check_dtypes=False)
|
|
|
|
|
|
|
|
def testReduceMin(self):
  # Each shard subtracts the cross-device minimum over 'i', which equals the
  # column-wise min of the stacked input.
  n_rows = jax.device_count()
  data = np.arange(n_rows * 4, dtype=np.float32).reshape((n_rows, 4))

  @partial(self.pmap, axis_name='i')
  def minus_min(row):
    return row - lax.pmin(row, 'i')

  self.assertAllClose(minus_min(data), data - data.min(axis=0),
                      check_dtypes=False)
|
|
|
|
|
2019-05-17 09:08:08 -07:00
|
|
|
def testDeviceCountError(self):
  # Every case below asks for more replicas than there are devices and must
  # raise a ValueError mentioning the required replica count.
  n = jax.device_count()
  too_many = [
      (self.pmap(lambda v: v), jnp.arange(n + 1)),
      (self.pmap(lambda v: v), np.ones((n + 1, 10))),
      # Nested: the outer axis alone fits, but outer * inner (n * 2) does not.
      (self.pmap(lambda v: self.pmap(lambda w: w)(v)), np.ones((n, 2, 10))),
  ]
  for mapped, arg in too_many:
    with self.assertRaisesRegex(ValueError, ".*requires.*replicas"):
      mapped(arg)
|
2019-05-17 09:08:08 -07:00
|
|
|
|
2019-05-29 10:39:51 -07:00
|
|
|
def testPmapConstant(self):
  # A pmapped function returning a Python constant must still yield one value
  # per device, both on its own and alongside a mapped output.
  n = jax.device_count()
  const_fn = self.pmap(lambda _: 3)
  arg = jnp.arange(n)
  with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
    out = const_fn(arg)
  # self.assertEqual(count[0], 0)  # TODO(mattjj): fix this
  want = np.repeat(3, n)
  self.assertAllClose(out, want, check_dtypes=False)

  pair_fn = self.pmap(lambda v: (v, 3))
  with jtu.assert_num_jit_and_pmap_compilations(1):
    _, out = pair_fn(np.arange(n))
  self.assertAllClose(out, want, check_dtypes=False)
|
|
|
|
|
|
|
|
def testPmapConstantDevices(self):
  if jax.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  # Use a shuffled strict subset of the devices so placement is non-trivial.
  chosen = jax.devices()[:-1]
  shuffle(chosen)
  const_fn = self.pmap(lambda _: 3, devices=chosen)
  arg = jnp.arange(len(chosen))
  with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
    out = const_fn(arg)
  # self.assertEqual(count[0], 0)  # TODO(mattjj): don't compile for constants
  self.assertAllClose(out, np.repeat(3, len(chosen)), check_dtypes=False)

  # The constant result must live on exactly the requested devices, in the
  # requested order.
  placed_on = [buf.device() for buf in out.device_buffers]
  self.assertEqual(placed_on, chosen)
|
|
|
|
|
|
|
|
def testPmapConstantError(self):
  # Requesting one more shard than there are devices must fail at compile time.
  n = jax.device_count()
  const_fn = self.pmap(lambda _: 3)
  too_big = jnp.arange(n + 1)
  expected_msg = (r"compiling computation that requires \d+ logical devices, "
                  r"but only \d+ XLA devices are available .*")
  with self.assertRaisesRegex(ValueError, expected_msg):
    const_fn(too_big)

  # TODO(mattjj): test error message with explicit devices
  # f = pmap(lambda x: 3, devices=[jax.devices()[0]])
  # x = jnp.arange(2)
  # self.assertRaisesRegex(
  #     ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
  #     r"local devices are available.", lambda: f(x))
|
2019-12-17 16:22:55 -08:00
|
|
|
|
|
|
|
def testNestedPmapConstant(self):
  if jax.device_count() == 1:
    raise SkipTest("this test requires multiple devices")

  def devices_of(arr):
    # Device of each shard buffer, in buffer order.
    return [buf.device() for buf in arr.device_buffers]

  outer, inner = 2, jax.device_count() // 2
  shape = (outer, inner, 3)
  arg = jnp.arange(prod(shape)).reshape(shape)
  want = 3 * np.ones((outer, inner))

  const_fn = self.pmap(self.pmap(lambda _: 3))
  with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
    out = const_fn(arg)
  # self.assertEqual(count[0], 0)  # TODO(mattjj): don't compile for constants
  self.assertAllClose(out, want, check_dtypes=False)

  # The constant result must be placed like a genuinely mapped input.
  reference = self.pmap(self.pmap(lambda v: v))(want)
  self.assertEqual(devices_of(out), devices_of(reference))

  # Same check when the constant is returned next to a mapped output.
  pair_fn = self.pmap(self.pmap(lambda v: (v, 3)))
  mapped, out = pair_fn(arg)
  self.assertAllClose(out, want, check_dtypes=False)
  self.assertEqual(devices_of(out), devices_of(mapped))
|
|
|
|
|
2021-10-04 17:54:18 -07:00
|
|
|
@unittest.skip("Nested pmaps with devices not yet implemented")
def testNestedPmapConstantDevices(self):
  if jax.device_count() < 6:
    raise SkipTest("this test requires >= 6 devices")

  chosen = jax.devices()[:-2]
  shuffle(chosen)
  const_fn = self.pmap(self.pmap(lambda _: 3), devices=chosen)
  shape = (2, len(chosen) // 2, 3)
  arg = jnp.arange(prod(shape)).reshape(shape)
  with jtu.count_jit_and_pmap_compiles() as count:  # noqa: F841
    out = const_fn(arg)
  # self.assertEqual(count[0], 0)  # TODO(mattjj): don't compile for constants
  want = 3 * np.ones(shape[:2])
  self.assertAllClose(out, want, check_dtypes=False)

  # The constant output should be sharded exactly like a mapped input placed
  # on the same explicit device list.
  reference = self.pmap(self.pmap(lambda v: v), devices=chosen)(want)
  self.assertEqual([buf.device() for buf in out.device_buffers],
                   [buf.device() for buf in reference.device_buffers])
|
|
|
|
|
|
|
|
def testNestedPmapConstantError(self):
  # The inner axis is one larger than half the device count, so the nested map
  # needs 2 * (device_count // 2 + 1) > device_count devices.
  const_fn = self.pmap(self.pmap(lambda _: 3))
  shape = (2, jax.device_count() // 2 + 1, 3)
  arg = jnp.arange(prod(shape)).reshape(shape)
  with self.assertRaisesRegex(
      ValueError,
      (r"compiling computation that requires \d+ logical devices, "
       r"but only \d+ XLA devices are available .*")):
    const_fn(arg)

  # TODO(mattjj): check error message with explicit devices
  # if jax.device_count() > 1:
  #   f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1])
  #   shape = (2, jax.device_count() // 2, 3)
  #   x = jnp.arange(prod(shape)).reshape(shape)
  #   self.assertRaisesRegex(
  #       ValueError,
  #       (r"compiling computation that requires \d+ replicas, "
  #        r"but only \d+ XLA devices are available"),
  #       lambda: f(x))
|
2020-07-30 12:59:36 -07:00
|
|
|
|
2019-05-29 10:39:51 -07:00
|
|
|
def testCollectiveConstant(self):
  # psum of the literal 1 over 'i' counts the participating devices.
  n = jax.device_count()
  count_fn = self.pmap(lambda _: lax.psum(1, 'i'), 'i')
  out = count_fn(jnp.arange(n))
  self.assertAllClose(out, np.repeat(n, n), check_dtypes=False)
|
|
|
|
|
|
|
|
def testCollectiveConstantNested(self):
  n = jax.device_count()

  def sizes(_):
    # Axis sizes as seen from inside the doubly-nested map.
    return lax.psum(1, 'i'), lax.psum(1, 'j'), lax.psum(1, ('i', 'j'))

  nested = self.pmap(lambda v: self.pmap(sizes, axis_name='j')(v),
                     axis_name='i')
  shape = (n, 1, 4)
  size_i, size_j, size_ij = nested(jnp.arange(prod(shape)).reshape(shape))

  for result in (size_i, size_j, size_ij):
    self.assertEqual(result.shape, shape[:-1])
  self.assertEqual(size_i.ravel()[0], n)
  self.assertEqual(size_j.ravel()[0], 1)
  self.assertEqual(size_ij.ravel()[0], n * 1)
|
|
|
|
|
|
|
|
def testAxisIndex(self):
  # Each shard adds its own position along 'i'.
  n = jax.device_count()
  add_index = self.pmap(lambda v: v + lax.axis_index('i'), 'i')
  out = add_index(jnp.ones(n))
  self.assertAllClose(out, np.arange(n) + 1, check_dtypes=False)
|
|
|
|
|
2020-09-22 13:04:53 +00:00
|
|
|
def testAxisIndexNestedPmap(self):
  if jax.device_count() < 4:
    raise SkipTest("test requires at least four devices")

  def nested_add_index(axis):
    # Outer map over 'i', inner over 'j'; add the index along `axis`.
    inner = self.pmap(lambda v: v + lax.axis_index(axis), 'j')
    return self.pmap(inner, 'i')

  ones = jnp.ones((2, 2))
  along_j = np.broadcast_to(np.arange(2) + 1, (2, 2))
  self.assertAllClose(nested_add_index('j')(ones), along_j, check_dtypes=False)
  self.assertAllClose(nested_add_index('i')(ones), along_j.T,
                      check_dtypes=False)
|
|
|
|
|
|
|
|
def testAxisIndexNd(self):
  if jax.device_count() < 4:
    raise SkipTest("test requires at least four devices")

  def nested_add_index(axes):
    # Outer map over 'i', inner over 'j'; add the combined index over `axes`.
    inner = self.pmap(lambda v: v + lax.axis_index(axes), 'j')
    return self.pmap(inner, 'i')

  ones = jnp.ones((2, 2))
  row_major = np.arange(4).reshape((2, 2)) + 1
  self.assertAllClose(nested_add_index(('i', 'j'))(ones), row_major,
                      check_dtypes=False)
  self.assertAllClose(nested_add_index(('j', 'i'))(ones), row_major.T,
                      check_dtypes=False)
|
|
|
|
|
2020-09-08 18:04:11 +00:00
|
|
|
def testAxisIndexInInitialStyle(self):
  # axis_index must also work inside initial-style control flow (lax.scan).
  def accumulate(row):
    def step(total, elt):
      return total + elt + lax.axis_index('i'), None
    final, _ = lax.scan(step, 0, row)
    return final

  scan_fn = self.pmap(accumulate, axis_name='i')
  n = jax.device_count()
  got = scan_fn(jnp.ones((n, 10), dtype=int))
  self.assertAllClose(got, (np.arange(n) + 1) * 10)
|
|
|
|
|
2019-06-04 18:33:52 -07:00
|
|
|
def testVmapOfPmap(self):
  # vmap over the leading axis of an identity pmap must round-trip the input.
  n = jax.device_count()
  mapped_identity = self.pmap(lambda v: v, axis_name='i')
  data = self.rng().randn(2, n, 50, 60)
  self.assertAllClose(data, vmap(mapped_identity)(data), check_dtypes=False)
|
|
|
|
|
2019-09-11 06:01:32 -07:00
|
|
|
def testVmapOfPmap2(self):
  n_devices = jax.device_count()
  batch_keys = random.split(random.PRNGKey(1), 13)  # [13, 2]

  def draw(key):
    _ = random.normal(key, ())
    return 0.
  draw_all = self.pmap(draw)

  def replicate_and_draw(key):
    # Broadcast one key to every device, then pmap over the copies.
    tiled = tree_util.tree_map(
        lambda leaf: jnp.broadcast_to(leaf, (n_devices,) + leaf.shape), key)
    return draw_all(tiled)

  out = vmap(replicate_and_draw)(batch_keys)  # doesn't crash
  self.assertEqual(out.shape, (13, n_devices))
|
2019-09-11 06:01:32 -07:00
|
|
|
|
2020-06-14 14:45:29 -07:00
|
|
|
def testVmapOfPmap3(self):
  # https://github.com/google/jax/issues/3399
  n = jax.device_count()
  if n < 2:
    raise SkipTest("test requires at least two devices")

  def func(q, pts):
    q_from_pmap = self.pmap(lambda x, y: y, in_axes=(0, None))(pts, q)
    return q, q_from_pmap

  pts = jnp.ones(n)
  qs = jnp.asarray(((0, 0), (3, 3), (2, 2)))

  # Reference: loop over the queries with lax.map (hence the jit-of-pmap
  # warning filter).
  with ignore_jit_of_pmap_warning():
    _, expected = jax.lax.map(lambda x: func(x, pts), qs)
  # Under test: batch the queries with vmap instead.
  _, ans = jax.vmap(func, in_axes=(0, None))(qs, pts)
  self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-06-04 18:33:52 -07:00
|
|
|
def testVmapOfPmapNonLeadingAxis(self):
  n = jax.device_count()
  mapped_identity = self.pmap(lambda v: v, axis_name='i')
  data = self.rng().randn(n, 2, 50, 60)
  # Batch along axis 2 on both input and output; identity must round-trip.
  roundtrip = vmap(mapped_identity, in_axes=2, out_axes=2)(data)
  self.assertAllClose(data, roundtrip, check_dtypes=False)
|
|
|
|
|
|
|
|
def testVmapOfPmapTuple(self):
  """vmap of an identity pmap over a pytree with mixed in/out batch axes."""
  device_count = jax.device_count()
  f0 = lambda *x: x
  f1 = self.pmap(f0, axis_name='i')

  ax = self.rng().randn(device_count, 2, 50, 60)
  ay = self.rng().randn(device_count, 30, 2)
  az1 = self.rng().randn(device_count, 20)
  az2 = self.rng().randn(2, device_count, 20)

  bx, by, bz = vmap(f1, in_axes=(1, 2, (None, 0)), out_axes=(1, 2, 0))(ax, ay, (az1, az2))

  # Identity with matching in/out batch axes returns the inputs unchanged.
  self.assertAllClose(ax, bx, check_dtypes=False)
  self.assertAllClose(ay, by, check_dtypes=False)

  bz1, bz2 = bz
  # az1 is unbatched (in_axes None), so the output is broadcast along the new
  # batch axis 0.
  expected_bz1 = np.broadcast_to(az1, (2,) + az1.shape)
  self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
  # az2 is batched along axis 0 on both input and output, so it must come back
  # unchanged. (Previously this compared bz2 with itself, which always passed.)
  self.assertAllClose(az2, bz2, check_dtypes=False)
|
|
|
|
|
2021-03-05 17:59:16 +00:00
|
|
|
@ignore_slow_all_to_all_warning()
def testPswapaxes(self):
  n = jax.device_count()
  shape = (n, 3, n, 5)
  data = np.arange(prod(shape)).reshape(shape)
  # Swapping the mapped axis 'i' with per-shard axis 1 (global axis 2) is a
  # global swap of axes 0 and 2.
  swapped = self.pmap(lambda v: lax.pswapaxes(v, 'i', 1), axis_name='i')(data)
  self.assertAllClose(swapped, np.swapaxes(data, 0, 2), check_dtypes=False)
|
|
|
|
|
2021-03-05 17:59:16 +00:00
|
|
|
@ignore_slow_all_to_all_warning()
def testGradOfPswapaxes(self):
  n = jax.device_count()
  shape = (n, 1, n)
  data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  weights = np.arange(n, dtype=np.float32)

  def grad_per_shard(v, w):
    loss = lambda u: jnp.sum(lax.pswapaxes(u, 'i', 1) * w)
    return grad(loss)(v)

  got = self.pmap(grad_per_shard, axis_name='i')(data, weights)
  self.assertAllClose(got, np.tile(weights, reps=n).reshape(shape),
                      check_dtypes=False)
|
|
|
|
|
2021-03-05 12:24:56 +00:00
|
|
|
@ignore_slow_all_to_all_warning()
def testAllToAllReplicaGroups(self):
  # If num_devices = 4, these would be the inputs/outputs:
  # input = [[0, 1], [2, 3], [4, 5], [6, 7]]
  # axis_index_groups = [[0, 2], [1, 3]]
  # output = [[0, 4], [2, 6], [1, 5], [3, 7]]
  #
  # The devices are split into two interleaved groups (even and odd indices).
  # Within a group, the device at position p collects element p of every group
  # member's row. If the input is viewed with shape
  # (device_count // 2, 2, device_count // 2), i.e.
  # (position in group, group, element), this is exactly a swap of axes 0 and
  # 2, which is what the expected value below computes.
  device_count = jax.device_count()
  if device_count % 2 != 0:
    raise SkipTest('test requires an even number of devices')
  shape = (device_count, device_count // 2)
  x = np.arange(prod(shape)).reshape(shape)

  axis_index_groups = np.arange(device_count, dtype=np.int32)
  axis_index_groups = axis_index_groups.reshape((device_count // 2, 2)).T
  axis_index_groups = axis_index_groups.tolist()

  @partial(self.pmap, axis_name='i')
  def fn(x):
    return lax.all_to_all(x, 'i', 0, 0, axis_index_groups=axis_index_groups)

  expected = np.swapaxes(
      x.reshape((device_count // 2, 2, device_count // 2)),
      0, 2).reshape(shape)
  self.assertAllClose(fn(x), expected, check_dtypes=False)
|
|
|
|
|
2021-03-05 12:24:56 +00:00
|
|
|
@ignore_slow_all_to_all_warning()
def testGradOfAllToAllReplicaGroups(self):
  n = jax.device_count()
  if n % 2 != 0:
    raise SkipTest('test requires an even number of devices')
  half = n // 2
  shape = (n, half, 1)
  data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
  weights = np.arange(n, dtype=np.float32)

  # Two contiguous groups: [0, ..., half - 1] and [half, ..., n - 1].
  groups = np.arange(n, dtype=np.int32).reshape((2, half)).tolist()

  def grad_per_shard(v, w):
    loss = lambda u: jnp.sum(
        lax.all_to_all(u, 'i', 0, 1, axis_index_groups=groups) * w)
    return grad(loss)(v)

  want = np.ones_like(data) * weights[:, np.newaxis, np.newaxis]
  want = np.swapaxes(want.reshape((2, half, half)), 1, 2).reshape(shape)
  got = self.pmap(grad_per_shard, axis_name='i')(data, weights)
  self.assertAllClose(got, want, check_dtypes=False)
|
|
|
|
|
2020-04-15 12:43:55 -07:00
|
|
|
def testReshardInput(self):
  if jax.device_count() < 6:
    raise SkipTest("testReshardInput requires 6 devices")

  # Hand-build a ShardedDeviceArray of global shape (6, 4) laid out as a 2x2
  # grid of (3, 2) blocks on four devices. This sharding does not match what
  # the pmap below expects, so pmap has to reshard its input.
  block_shape = (3, 2)
  block = jnp.arange(prod(block_shape)).reshape(block_shape)
  buffers = pxla.device_put(block, jax.devices()[:4], replicate=True)
  spec = pxla.ShardingSpec(
      sharding=[pxla.Chunked([2]), pxla.Chunked([2])],
      mesh_mapping=[pxla.ShardedAxis(0), pxla.ShardedAxis(1)])
  sharded = pxla.make_sharded_device_array(
      ShapedArray((6, 4), block.dtype), spec, buffers)

  result = self.pmap(lambda v: v + 1)(sharded)
  self.assertAllClose(result, sharded + 1)
  # pmap maps over the leading axis of size 6.
  self.assertEqual(len(result.device_buffers), 6)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapBatchMatmul(self):
  batch = 4 * jax.device_count()
  lhs = np.arange(batch * 2 * 3).reshape(batch, 2, 3)
  rhs = np.arange(batch * 3 * 4).reshape(batch, 3, 4)
  got = soft_pmap(jnp.dot, 'i')(lhs, rhs)
  # Per-example matmul over the leading (mapped) axis.
  self.assertAllClose(got, np.einsum('nij,njk->nik', lhs, rhs),
                      check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapBatchMatmulJit(self):
  batch = 4 * jax.device_count()
  lhs = np.arange(batch * 2 * 3).reshape(batch, 2, 3)
  rhs = np.arange(batch * 3 * 4).reshape(batch, 3, 4)
  # Same as the un-jitted variant, but soft_pmap wraps a jitted function.
  got = soft_pmap(jit(jnp.dot), 'i')(lhs, rhs)
  self.assertAllClose(got, np.einsum('nij,njk->nik', lhs, rhs),
                      check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapPsumConstant(self):
  size = 4 * jax.device_count()
  # psum of the literal 1 yields the full logical axis size.
  axis_size = lambda _: lax.psum(1, 'i')
  got = soft_pmap(axis_size, 'i')(jnp.ones(size))
  self.assertAllClose(got, size * np.ones(size), check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapPsum(self):
  size = 4 * jax.device_count()
  normalize = lambda v: v / lax.psum(v, 'i')
  got = soft_pmap(normalize, 'i')(jnp.ones(size))
  self.assertAllClose(got, np.ones(size) / size, check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapAxisIndex(self):
  size = 4 * jax.device_count()
  scale_by_index = lambda v: v * lax.axis_index('i')
  got = soft_pmap(scale_by_index, 'i')(2 * jnp.ones(size))
  self.assertAllClose(got, 2 * np.arange(size), check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapOfJit(self):
  size = 4 * jax.device_count()
  triple = lambda v: 3 * v
  got = soft_pmap(jit(triple), 'i')(np.arange(size))
  self.assertAllClose(got, 3 * np.arange(size), check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
@unittest.skip("not implemented")  # TODO(mattjj): re-implement
def testSoftPmapNested(self):
  size = 4 * jax.device_count()

  def combined_index(v):
    i_size = lax.psum(1, 'i')
    return v + lax.axis_index('i') + i_size * lax.axis_index('j')

  nested = soft_pmap(soft_pmap(combined_index, axis_name='j'), axis_name='i')
  got = nested(jnp.zeros((size, size)))
  self.assertAllClose(got, np.arange(size ** 2).reshape(size, size).T,
                      check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
@unittest.skip("not implemented")  # TODO(mattjj): re-implement
def testGradOfSoftPmap(self):
  size = 4 * jax.device_count()

  def scale_by_index(v):
    return v * lax.axis_index('i')
  mapped = soft_pmap(scale_by_index, axis_name='i')

  # d/dx sum(x * i) is i, broadcast across each row.
  got = grad(lambda v: jnp.sum(mapped(v)))(jnp.zeros((size, size)))
  self.assertAllClose(got, np.repeat(np.arange(size)[:, None], size, axis=1),
                      check_dtypes=False)
|
|
|
|
|
2021-01-26 19:38:40 -08:00
|
|
|
@ignore_xmap_warning()
def testSoftPmapDevicePersistence(self):
  n = jax.device_count()
  shape = (2 * 2 * n, 2, 3)

  # check that we can maintain device persistence across calls
  arr = soft_pmap(lambda v: v)(np.arange(prod(shape)).reshape(shape))
  self.assertIsInstance(arr, pxla.ShardedDeviceArray)
  # Replace the cached host value with a scalar NaN so the array can't be
  # coerced to an ndarray for transfer; the next call must not rely on it.
  arr._npy_value = np.float32(np.nan)  # can't be coerced to ndarray for xfer
  arr = soft_pmap(lambda v: v)(arr)  # doesn't crash
  self.assertIsInstance(arr, pxla.ShardedDeviceArray)
|
2019-07-06 10:00:08 -07:00
|
|
|
|
2021-10-04 17:54:18 -07:00
|
|
|
@unittest.skip("the underlying code here is broken") # TODO(mattjj)
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, setting `call_primitive=True` or `map_primitive=True` implies which `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
def testSoftPmapAllToAll(self):
  size = 4 * jax.device_count()
  # all_to_all splitting and concatenating on axis 0 transposes the square.
  exchange = lambda v: lax.all_to_all(v, 'i', 0, 0)
  square = jnp.arange(size ** 2).reshape(size, size)
  got = soft_pmap(exchange, 'i')(square)
  self.assertAllClose(got, np.arange(size ** 2).reshape(size, size).T,
                      check_dtypes=False)
|
|
|
|
|
2019-07-03 21:15:52 -07:00
|
|
|
def testShardedDeviceArrayBlockUntilReady(self):
  # block_until_ready on a pmap result should simply succeed.
  arr = self.pmap(lambda v: v)(np.arange(jax.device_count()))
  arr.block_until_ready()  # doesn't crash
|
2019-07-03 21:15:52 -07:00
|
|
|
|
2020-08-10 19:09:34 +02:00
|
|
|
@ignore_jit_of_pmap_warning()
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
def testJitPmapComposition(self):
|
|
|
|
f = lambda x: x - lax.psum(x, 'i')
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
shape = (jax.device_count(), 4)
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
|
|
|
expected = x - np.sum(x, 0)
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
ans = jit(self.pmap(f, 'i'))(x)
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
ans = self.pmap(jit(f), 'i')(x)
|
enable jit+pmap by merging pxla.py and xla.py
This change is essentially de-duplicating the XLA lowering logic between
xla.py and pxla.py. Only the latter was capable of handling collectives
(aka pmap primitives), which meant that these didn't work:
1. some compositions of jit and pmap, like jit-of-pmap
2. collectives inside initial-style control flow like scan
3. jax.xla_computation on a function involving collectives
By merging the logic into xla.py, now all the lowering machinery works
with everything. Woo!
The pxla.py file still exists and contains mostly dynamic/runtime
components for pmap and functions used only by pmap and collectives
translations. In particular, pxla.py has
* the pmap impl, particularly the dispatching logic for top-level pmaps,
including argument sharding and lazy sharded result persistence
* the ShardedDeviceArray / ShardedDeviceTuple classes
* the dynamic (trace-time) axis environment data structures and logic
and the special axis_index primitive
* the split-axis transformation for soft_pmap
* the PmapPrimitive (just a tagged version of Primitive)
* the static sharding/unsharding logic for pmap-inside-jit/pmap
These things moved over to xla.py
* the logic for lowering pmap primitives, especially the static axis
environment used during xla lowering
This change refactors the translation rule tables a bit. Instead of just
having one table, there are now four, and they contain rules with
slightly different type signatures:
* the `translations` table has rules with the same signatures as always,
i.e. `CompBuilder -> [XlaOperands] -> ParamsDict -> XlaOperandOut`
* the `backend_specific_translations` table is keyed by platform name
strings and has dict values that each have the same type as `translations`
* the `parallel_translations` table is used for primitives modeling
parallel collectives, and so it has rules with signature
`CompBuilder -> [XlaOperands] -> ReplicaGroups -> ParamsDict -> XlaOpOut`
* the `initial_style_translations` table is for the initial-style
control flow primitives (like `scan`), for which the translation rules
themselves lower jaxprs to XLA computations and thus require the static axis
env to be passed in; the rules there have signature
`CompBuilder -> AxisEnv -> [XlaOperands] -> ParamsDict -> XlaOpOut`
* the `call_translations` table is sued for `xla_call` and `xla_pmap`,
i.e. the primitives underlying `jit` and `pmap` respectively, and has
rules with signature
`CompBuilder -> Jaxpr -> AxisEnv -> [XlaOp] -> [XlaOp] -> ParamsDict -> XlaOp`
Having these as separate tables is an uninteresting implementation
detail. The lowering function `_jaxpr_computation` just does a case analysis
on whether the primitive being translated has an entry in any table
(where the `backend_specific_translations` table must be checked before
the `translations` table, since some primitives may be entered in both).
This change fixes #804 also addresses #852, in that the lax control flow
impls for those primitives are now based on Python-level jaxpr
interpreters rather than XLA compilation, but we should probably wait to
close the latter issue until we benchmark and improve things more. This
change at least seems not to be a performance regression: on my machine
the lax control flow tests go from running in ~20s to running in ~14s.
This change also adds a docstring for `jax.xla_computation` and some
basic tests.
2019-07-02 13:17:31 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-07-08 18:21:43 -07:00
|
|
|
def testCompositionWithJitTwice(self):
|
|
|
|
@jit
|
|
|
|
def f(x):
|
|
|
|
y = 2 * x
|
2019-07-09 15:12:02 -07:00
|
|
|
|
2019-07-08 18:21:43 -07:00
|
|
|
@jit
|
|
|
|
def g(z):
|
2022-01-28 08:16:30 -08:00
|
|
|
return self.pmap(lambda x: x[jnp.newaxis] * y)(z)
|
2019-07-09 15:12:02 -07:00
|
|
|
|
2019-07-08 18:21:43 -07:00
|
|
|
return g(x)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
f(np.arange(1.).reshape((1, 1))) # doesn't crash
|
2019-07-08 18:21:43 -07:00
|
|
|
|
2020-08-10 19:09:34 +02:00
|
|
|
@ignore_jit_of_pmap_warning()
|
2019-07-25 18:11:44 -07:00
|
|
|
def testIssue1065(self):
|
|
|
|
# from https://github.com/google/jax/issues/1065
|
2021-09-23 06:33:25 -07:00
|
|
|
device_count = jax.device_count()
|
2019-07-25 18:11:44 -07:00
|
|
|
|
|
|
|
def multi_step_pmap(state, count):
|
2021-08-04 14:46:21 -07:00
|
|
|
@partial(self.pmap, axis_name='x')
|
2019-07-25 18:11:44 -07:00
|
|
|
@jit
|
|
|
|
def exchange_and_multi_step(state):
|
|
|
|
return state
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def time_evolution(state):
|
|
|
|
return lax.fori_loop(0, count, lambda i, s: exchange_and_multi_step(s), state)
|
|
|
|
|
|
|
|
return time_evolution(state)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
multi_step_pmap(jnp.zeros((device_count,)), count=1)
|
2019-07-25 18:11:44 -07:00
|
|
|
|
2019-08-21 16:39:59 -07:00
|
|
|
def testShardedDeviceArrayGetItem(self):
|
|
|
|
f = lambda x: 2 * x
|
2021-08-04 14:46:21 -07:00
|
|
|
f = self.pmap(f, axis_name='i')
|
2019-08-21 16:39:59 -07:00
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
shape = (jax.device_count(), 4)
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
2019-08-21 16:39:59 -07:00
|
|
|
|
|
|
|
y = f(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertIsInstance(y, jnp.ndarray)
|
2019-08-21 16:39:59 -07:00
|
|
|
self.assertIsInstance(y, pxla.ShardedDeviceArray)
|
|
|
|
|
|
|
|
z = y[0] # doesn't crash
|
2019-08-21 20:36:47 -07:00
|
|
|
self.assertAllClose(z, 2 * x[0], check_dtypes=False)
|
2019-08-21 16:39:59 -07:00
|
|
|
|
2021-10-04 17:54:18 -07:00
|
|
|
# TODO(mattjj): this fails with multiple devices (unless we add a jit)
|
|
|
|
# because we assume eager ops (like scan here) can't require more than 1
|
|
|
|
# replica.
|
|
|
|
@unittest.skip("need eager multi-replica support")
|
2019-09-20 07:01:01 -07:00
|
|
|
def testPostProcessMap(self):
|
|
|
|
# test came from https://github.com/google/jax/issues/1369
|
2021-09-23 06:33:25 -07:00
|
|
|
nrep = jax.device_count()
|
2019-09-20 07:01:01 -07:00
|
|
|
|
|
|
|
def pmvm(a, b):
|
|
|
|
a = a.reshape((nrep, -1, a.shape[1]))
|
2021-08-04 14:46:21 -07:00
|
|
|
func = self.pmap(lambda z: jnp.dot(z, b))
|
2019-09-20 07:01:01 -07:00
|
|
|
return func(a).reshape(b.shape)
|
|
|
|
|
2019-09-20 20:45:01 -07:00
|
|
|
n = nrep * 2
|
2021-12-10 10:32:09 -08:00
|
|
|
rng = self.rng()
|
2019-09-20 20:45:01 -07:00
|
|
|
a = rng.randn(n, n)
|
|
|
|
b = rng.randn(n)
|
2019-09-20 07:01:01 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
iters = jnp.arange(5)
|
2019-09-20 07:01:01 -07:00
|
|
|
def body(carry, i):
|
|
|
|
return pmvm(a, carry), i
|
|
|
|
ans, _ = lax.scan(body, b, iters)
|
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.linalg.matrix_power(a, 5).dot(b)
|
2019-09-20 07:01:01 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2019-09-18 17:21:57 -07:00
|
|
|
def testManyArgs(self):
|
2021-08-04 14:46:21 -07:00
|
|
|
@self.pmap
|
2019-09-18 17:21:57 -07:00
|
|
|
def f(args_list):
|
|
|
|
return sum(args_list)
|
|
|
|
|
|
|
|
vals = list(range(500))
|
2021-09-23 06:33:25 -07:00
|
|
|
ndevices = jax.device_count()
|
2020-05-05 14:59:16 -04:00
|
|
|
self.assertAllClose(f(jnp.array([vals] * ndevices)),
|
2020-06-01 17:19:23 -04:00
|
|
|
jnp.array([sum(vals)] * ndevices))
|
2019-09-18 17:21:57 -07:00
|
|
|
|
2020-06-01 15:28:57 -07:00
|
|
|
def testPostProcessMap2(self):
|
2020-04-21 18:12:02 -07:00
|
|
|
# code from https://github.com/google/jax/issues/2787
|
|
|
|
def vv(x, y):
|
|
|
|
"""Vector-vector multiply"""
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.dot(x, y)
|
2020-04-21 18:12:02 -07:00
|
|
|
|
|
|
|
def distributed_matrix_vector(x, y):
|
|
|
|
"""Matrix vector multiply. First batch it and then row by row"""
|
|
|
|
fv = lambda z: lax.map(lambda j: vv(j, y), z)
|
2021-08-04 14:46:21 -07:00
|
|
|
res = self.pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
|
2020-04-21 18:12:02 -07:00
|
|
|
res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
|
|
|
|
return res
|
|
|
|
|
|
|
|
key = random.PRNGKey(1)
|
2020-04-21 18:27:53 -07:00
|
|
|
x = random.normal(key, (80, 50))
|
2020-04-21 18:12:02 -07:00
|
|
|
batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
|
|
|
|
y = random.normal(key, (10, 50, 1))
|
|
|
|
result = batched_mvm(y)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = jnp.einsum('ij,njk->nik', x, y)
|
2020-04-21 18:27:53 -07:00
|
|
|
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
|
|
|
|
self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)
|
2020-04-21 18:12:02 -07:00
|
|
|
|
2020-04-23 13:34:01 -07:00
|
|
|
def testAxisIndexRemat(self):
|
|
|
|
# https://github.com/google/jax/issues/2716
|
|
|
|
n = len(jax.devices())
|
|
|
|
|
|
|
|
def f(key):
|
|
|
|
key = random.fold_in(key, jax.lax.axis_index('i'))
|
|
|
|
return random.bernoulli(key, p=0.5)
|
|
|
|
|
|
|
|
keys = random.split(random.PRNGKey(0), n)
|
2021-08-04 14:46:21 -07:00
|
|
|
self.pmap(jax.remat(f), axis_name='i')(keys)
|
2020-04-23 13:34:01 -07:00
|
|
|
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
def testPmapMapVmapCombinations(self):
|
|
|
|
# https://github.com/google/jax/issues/2822
|
|
|
|
def vv(x, y):
|
|
|
|
"""Vector-vector multiply"""
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.dot(x, y)
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
|
|
|
|
def matrix_vector(x, y, parallel=True):
|
|
|
|
"""Matrix vector multiply. First batch it and then row by row"""
|
|
|
|
fv = lambda z: lax.map(lambda j: vv(j, y), z)
|
|
|
|
if parallel:
|
|
|
|
# split leading axis in two
|
|
|
|
new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
|
|
|
|
# apply map
|
2021-08-04 14:46:21 -07:00
|
|
|
new_res = self.pmap(fv)(new_x)
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
# reshape back out
|
|
|
|
res = new_res.reshape(x.shape[0], *new_res.shape[2:])
|
|
|
|
else:
|
|
|
|
res = fv(x)
|
|
|
|
return res
|
|
|
|
|
|
|
|
x = random.normal(random.PRNGKey(1), (80, 5))
|
|
|
|
y = random.normal(random.PRNGKey(1), (10, 5))
|
|
|
|
|
|
|
|
result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap
|
|
|
|
result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map
|
2021-10-13 10:56:21 -04:00
|
|
|
with ignore_jit_of_pmap_warning():
|
|
|
|
result3 = lax.map(lambda b: matrix_vector(x, b, True), y) # map + pmap
|
2020-05-05 14:59:16 -04:00
|
|
|
result4 = jnp.stack([matrix_vector(x, b, False) for b in y]) # none + map
|
handle mapped_invars correctly in more places (#2828)
fixes #2822
We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars` that it was working correctly before #1959.) In particular, in #1959 we:
1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive),
2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown),
3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive),
4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said),
5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False.
The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we add inputs).
This commit fixes those issues by
1. making `mapped_invars` non-optional,
2. handling `mapped_invars` correctly in
* JaxprTrace.process_map
* JVPTrace.process_map
* ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs)
* ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs)
3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829.
This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in
Python mechanisms. Moreover, when `call_primitive=True` or `map_primitive=True` implies things about what `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first.
2020-04-24 18:45:34 -07:00
|
|
|
|
|
|
|
self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
|
|
|
|
self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
|
|
|
|
self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)
|
|
|
|
|
2020-05-19 15:51:07 -07:00
|
|
|
def testPmapAxisNameError(self):
|
|
|
|
# https://github.com/google/jax/issues/3120
|
|
|
|
a = np.arange(4)[np.newaxis,:]
|
|
|
|
def test(x):
|
|
|
|
return jax.lax.psum(x, axis_name='batch')
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(NameError, "unbound axis name: batch"):
|
2021-08-04 14:46:21 -07:00
|
|
|
self.pmap(test)(a)
|
2020-05-19 15:51:07 -07:00
|
|
|
|
2020-05-19 15:41:03 -07:00
|
|
|
def testPsumOnBooleanDtype(self):
|
|
|
|
# https://github.com/google/jax/issues/3123
|
2021-09-23 06:33:25 -07:00
|
|
|
n = jax.device_count()
|
2020-05-19 15:41:03 -07:00
|
|
|
if n > 1:
|
|
|
|
x = jnp.array([True, False])
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
out = self.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)
|
2020-05-19 15:41:03 -07:00
|
|
|
self.assertEqual(list(out), [1, 1])
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
out = self.pmap(lambda x: jax.lax.pmean(x, 'i'), 'i')(x)
|
2020-05-19 15:41:03 -07:00
|
|
|
self.assertEqual(list(out), [1/2, 1/2])
|
|
|
|
else:
|
|
|
|
x = jnp.array([True])
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
out = self.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)
|
2020-05-19 15:41:03 -07:00
|
|
|
self.assertEqual(list(out), [1])
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
out = self.pmap(lambda x: jax.lax.pmean(x, 'i'), 'i')(x)
|
2020-05-19 15:41:03 -07:00
|
|
|
self.assertEqual(list(out), [1])
|
|
|
|
|
2020-06-23 09:29:58 -04:00
|
|
|
def testPsumWithNoAxisDoesntLeakFunctions(self):
|
|
|
|
x = jnp.ones((1, 1024), dtype=np.float32)
|
|
|
|
f = lambda _: x
|
|
|
|
w = weakref.ref(f)
|
2021-08-04 14:46:21 -07:00
|
|
|
g = self.pmap(f)
|
2020-06-23 09:29:58 -04:00
|
|
|
g(np.ones((1,), dtype=np.float32)).block_until_ready()
|
|
|
|
del f, g
|
|
|
|
gc.collect()
|
|
|
|
# 'f' should not be alive at this point; in particular the pmap cache must
|
|
|
|
# not keep it alive.
|
|
|
|
self.assertTrue(w() is None)
|
|
|
|
|
2020-06-12 16:10:45 -07:00
|
|
|
def testJitOfPmapWarningMessage(self):
|
2021-09-23 06:33:25 -07:00
|
|
|
device_count = jax.device_count()
|
2020-06-12 16:10:45 -07:00
|
|
|
|
|
|
|
if device_count == 1:
|
|
|
|
raise SkipTest("test requires at least two devices")
|
|
|
|
|
|
|
|
def foo(x): return x
|
|
|
|
|
|
|
|
with warnings.catch_warnings(record=True) as w:
|
|
|
|
warnings.simplefilter("always")
|
2021-08-04 14:46:21 -07:00
|
|
|
jit(self.pmap(foo))(jnp.arange(device_count))
|
2020-06-12 16:10:45 -07:00
|
|
|
|
|
|
|
self.assertGreaterEqual(len(w), 1)
|
|
|
|
self.assertIn("The jitted function foo includes a pmap",
|
|
|
|
str(w[-1].message))
|
|
|
|
|
2020-07-03 10:00:25 -07:00
|
|
|
def testPsumZeroCotangents(self):
|
|
|
|
# https://github.com/google/jax/issues/3651
|
|
|
|
def loss(params, meta_params):
|
|
|
|
(net, mpo) = params
|
|
|
|
return meta_params * mpo * net
|
|
|
|
|
|
|
|
def inner(meta_params, params):
|
|
|
|
grads = jax.grad(loss)(params, meta_params)
|
|
|
|
grads = lax.psum(grads, axis_name="i")
|
|
|
|
net_grads, mpo_grads = grads
|
|
|
|
net = params[0] + net_grads
|
|
|
|
mpo = params[1]
|
|
|
|
return mpo * net
|
|
|
|
|
|
|
|
def outer(params):
|
|
|
|
meta_params = jnp.array(4.0)
|
|
|
|
return jax.grad(inner)(meta_params, params)
|
|
|
|
|
|
|
|
params = (jnp.array([2.0]), jnp.array([3.0]))
|
2021-08-04 14:46:21 -07:00
|
|
|
self.pmap(outer, axis_name='i')(params) # doesn't crash
|
2020-07-03 10:00:25 -07:00
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
f = self.pmap(outer, axis_name='i')
|
2020-07-03 10:00:25 -07:00
|
|
|
jtu.check_grads(f, (params,), 2, ["fwd", "rev"], 1e-3, 1e-3)
|
|
|
|
|
2020-08-10 19:09:34 +02:00
|
|
|
@ignore_jit_of_pmap_warning()
|
2020-07-30 12:59:36 -07:00
|
|
|
def test_issue_1062(self):
|
|
|
|
# code from https://github.com/google/jax/issues/1062 @shoyer
|
|
|
|
# this tests, among other things, whether ShardedDeviceTuple constants work
|
2021-09-23 06:33:25 -07:00
|
|
|
device_count = jax.device_count()
|
2020-07-30 12:59:36 -07:00
|
|
|
|
|
|
|
@jit
|
|
|
|
def multi_step(state, count):
|
|
|
|
return lax.fori_loop(0, count, lambda i, s: s, state)
|
|
|
|
|
|
|
|
@jit
|
|
|
|
def multi_step_pmap(state, count=2):
|
2021-08-04 14:46:21 -07:00
|
|
|
@partial(self.pmap, axis_name='x')
|
2020-07-30 12:59:36 -07:00
|
|
|
def pmapped_multi_step(state):
|
|
|
|
return multi_step(state, count)
|
|
|
|
|
|
|
|
return pmapped_multi_step(state)
|
|
|
|
|
|
|
|
u = np.ones((device_count, 100))
|
|
|
|
multi_step_pmap(u) # doesn't crash
|
|
|
|
|
2020-09-11 22:40:12 -07:00
|
|
|
@jtu.skip_on_devices("cpu")
|
|
|
|
def test_replicate_backend(self):
|
2021-03-04 00:25:16 +00:00
|
|
|
# TODO(skye): fix backend caching so we always have multiple CPUs available
|
|
|
|
if jax.device_count("cpu") < 4:
|
|
|
|
self.skipTest("test requires 4 CPU device")
|
2020-09-11 22:40:12 -07:00
|
|
|
# https://github.com/google/jax/issues/4223
|
|
|
|
def fn(indices):
|
|
|
|
return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)
|
2021-08-04 14:46:21 -07:00
|
|
|
mapped_fn = self.pmap(fn, axis_name='i', backend='cpu')
|
|
|
|
mapped_fn = self.pmap(mapped_fn, axis_name='j', backend='cpu')
|
2020-09-11 22:40:12 -07:00
|
|
|
indices = np.array([[[2], [1]], [[0], [0]]])
|
|
|
|
mapped_fn(indices) # doesn't crash
|
2020-07-30 12:59:36 -07:00
|
|
|
|
2020-11-25 15:23:00 -08:00
|
|
|
@ignore_xmap_warning()
|
|
|
|
def testPdotBasic(self):
|
|
|
|
num_devices = jax.device_count()
|
|
|
|
|
|
|
|
def f(x, y):
|
|
|
|
return lax.pdot(x, y, 'i')
|
|
|
|
|
|
|
|
x = jnp.arange(num_devices * 3).reshape(num_devices, 3)
|
|
|
|
y = jnp.arange(num_devices * 5).reshape(num_devices, 5)
|
2021-08-04 14:46:21 -07:00
|
|
|
z = self.pmap(f, axis_name='i', out_axes=None)(x, y)
|
2020-11-25 15:23:00 -08:00
|
|
|
self.assertAllClose(z, jnp.dot(x.T, y))
|
|
|
|
|
2021-02-08 20:24:19 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_shape={}_axis={}_collective={}".format(
|
|
|
|
jtu.format_shape_dtype_string(shape, dtype),
|
|
|
|
axis, collective.__name__.replace(" ", "")),
|
|
|
|
"shape": shape, "dtype": dtype, "axis": axis,
|
|
|
|
"collective": collective, "bulk_op": bulk_op}
|
2021-02-11 08:30:37 -08:00
|
|
|
for collective, bulk_op in [
|
|
|
|
(parallel.pargmax, jnp.argmax),
|
|
|
|
(parallel.pargmin, jnp.argmin)
|
|
|
|
]
|
2021-02-08 20:24:19 -08:00
|
|
|
for dtype in [np.float32, np.int32]
|
2021-02-11 08:30:37 -08:00
|
|
|
for shape in [(4,), (2, 2), (2, 4), (4, 2)]
|
2021-02-08 20:24:19 -08:00
|
|
|
for axis in range(len(shape))
|
|
|
|
)
|
|
|
|
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() < shape[axis]:
|
2021-02-11 08:30:37 -08:00
|
|
|
raise SkipTest(f"test requires at least {shape[axis]} devices")
|
|
|
|
if (jtu.device_under_test() == 'cpu' and
|
|
|
|
np.issubdtype(dtype, np.floating) and
|
|
|
|
len(shape) > 1):
|
|
|
|
raise SkipTest("skipped on cpu due to strange failures") # TODO(mattjj)
|
2021-02-08 20:24:19 -08:00
|
|
|
|
|
|
|
rng = jtu.rand_default(self.rng())
|
|
|
|
x = rng(shape, dtype)
|
2021-08-04 14:46:21 -07:00
|
|
|
ans = self.pmap(lambda x: collective(x, 'i'), in_axes=axis, out_axes=None,
|
2021-02-08 20:24:19 -08:00
|
|
|
axis_name='i')(x)
|
|
|
|
expected = bulk_op(x, axis=axis)
|
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2021-03-25 12:43:31 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_dtype={}".format(
|
|
|
|
jtu.format_shape_dtype_string((), dtype)),
|
|
|
|
"dtype": dtype}
|
|
|
|
for dtype in [np.float32, np.int32]
|
|
|
|
)
|
|
|
|
def testPmapDtype(self, dtype):
|
|
|
|
# Regression test for https://github.com/google/jax/issues/6022
|
2021-08-04 14:46:21 -07:00
|
|
|
@partial(self.pmap, axis_name='i')
|
2021-03-25 12:43:31 -07:00
|
|
|
def func(_):
|
|
|
|
return jax.lax.psum(dtype(0), axis_name='i')
|
2021-09-23 06:33:25 -07:00
|
|
|
unused_arg = jnp.arange(jax.device_count())
|
2021-03-25 12:43:31 -07:00
|
|
|
out_dtype = func(unused_arg).dtype
|
|
|
|
self.assertEqual(out_dtype, dtype)
|
|
|
|
|
2021-07-29 10:34:43 -07:00
|
|
|
def test_num_replicas_with_switch(self):
|
|
|
|
# https://github.com/google/jax/issues/7411
|
|
|
|
def identity(x):
|
|
|
|
return x
|
|
|
|
|
|
|
|
def cond_of_pmap(x):
|
|
|
|
y = lax.cond(True, jax.pmap(identity), jax.pmap(identity), x)
|
|
|
|
return y
|
|
|
|
|
2021-10-13 10:56:21 -04:00
|
|
|
with ignore_jit_of_pmap_warning():
|
|
|
|
cond_of_pmap(jnp.zeros((jax.device_count(), 2)))
|
2021-07-29 10:34:43 -07:00
|
|
|
|
2021-09-07 08:50:02 -07:00
|
|
|
def test_static_argnum_on_method(self):
|
|
|
|
|
|
|
|
class A:
|
|
|
|
|
|
|
|
@partial(self.pmap, static_broadcasted_argnums=(0,))
|
|
|
|
def my_func_pmap(self, x):
|
|
|
|
return x + 2
|
|
|
|
|
|
|
|
A().my_func_pmap(jnp.asarray([3] * jax.device_count()))
|
|
|
|
|
2021-07-19 13:11:38 -04:00
|
|
|
def test_pmap_error_on_non_hashable_static_argument(self):
|
|
|
|
f = lambda x, y: x + 3
|
|
|
|
pmapped_f = self.pmap(f, static_broadcasted_argnums=(1,))
|
|
|
|
|
|
|
|
inputs = np.asarray([1] * jax.device_count())
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "Non-hashable static arguments are not supported.*"):
|
|
|
|
pmapped_f(inputs, np.asarray(1))
|
|
|
|
|
2021-10-13 11:06:17 -07:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": f"_axis_size={axis_size}", "axis_size": axis_size}
|
|
|
|
for axis_size in [1, 2])
|
|
|
|
def test_grad_of_pmap_compilation_caching(self, axis_size):
|
|
|
|
if len(jax.local_devices()) < axis_size:
|
|
|
|
raise SkipTest("too few devices for test")
|
|
|
|
|
|
|
|
@jax.pmap
|
|
|
|
def f(x):
|
|
|
|
return jnp.sin(x)
|
2021-07-19 13:11:38 -04:00
|
|
|
|
2021-10-13 11:06:17 -07:00
|
|
|
x = jnp.ones(axis_size)
|
|
|
|
f(x) # warm-up any dispatching compilations
|
|
|
|
|
|
|
|
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
|
|
|
_, f_bwd = jax.vjp(f, x)
|
|
|
|
_ = f_bwd(x)
|
|
|
|
self.assertEqual(count[0], 2) # one for fwd, one for bwd
|
|
|
|
|
|
|
|
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
|
|
|
_ = jax.vjp(f, x)
|
|
|
|
_ = f_bwd(x)
|
|
|
|
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
|
2020-06-12 16:10:45 -07:00
|
|
|
|
2021-11-17 10:01:14 -08:00
|
|
|
def testSizeOverflow(self):
|
|
|
|
x = jnp.arange(1)
|
|
|
|
x = self.pmap(lambda _: jnp.ones([8, 267736, 1024], dtype=jnp.int8))(x)
|
|
|
|
self.assertEqual(x.size, 8 * 267736 * 1024)
|
|
|
|
self.assertEqual(type(x.size), int)
|
|
|
|
|
2021-08-17 06:11:07 -07:00
|
|
|
class CppPmapTest(PythonPmapTest):
  """Re-runs the PythonPmapTest cases with `self.pmap` bound to the C++ pmap."""

  @property
  def pmap(self):
    return src_api._cpp_pmap

  def test_pmap_fast_path_is_enabled(self):
    """A first call through the C++ pmap adds exactly one cache entry."""
    # Previously named `pmap_fast_path_is_enabled`: without the `test` prefix
    # unittest never collected it, so this check silently never ran. It now
    # goes through `self.pmap` so it exercises the implementation under test.
    num_devices = jax.device_count()
    f = self.pmap(lambda x: x + 1)
    size = f._cache_size()
    f(np.zeros([num_devices], dtype=np.float32))
    self.assertEqual(f._cache_size(), size + 1)

  # Backward-compatible alias for the old name (not collected by unittest).
  pmap_fast_path_is_enabled = test_pmap_fast_path_is_enabled
|
|
|
|
|
2021-08-04 14:46:21 -07:00
|
|
|
|
2020-06-15 09:10:40 -07:00
|
|
|
class VmapOfPmapTest(jtu.JaxTestCase):
|
|
|
|
|
2021-04-13 10:27:48 +00:00
|
|
|
# TODO(apaszke)
|
|
|
|
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
|
|
|
"testcase_name": f"{shapes}_{vmap_in_axes}_{vmap_out_axes}_{pmap_in_axes}_{pmap_out_axes}",
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
"shapes": shapes,
|
|
|
|
"vmap_in_axes": vmap_in_axes, "vmap_out_axes": vmap_out_axes,
|
2021-04-13 10:27:48 +00:00
|
|
|
"pmap_in_axes": pmap_in_axes, "pmap_out_axes": pmap_out_axes
|
|
|
|
} for arg_shapes in s(compatible_shapes)
|
|
|
|
for num_args in s(range(1, 4))
|
|
|
|
for shapes in s(list(it.combinations_with_replacement(arg_shapes, num_args)))
|
|
|
|
for vmap_in_axes in s(all_bdims(*shapes, pmap=False))
|
|
|
|
for pmap_in_axes in s(all_bdims(*shapes, pmap=True))
|
|
|
|
for vmap_out_axes in s(out_bdims(shapes[0], False))
|
|
|
|
for pmap_out_axes in s(out_bdims(shapes[0], True))
|
|
|
|
)))
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
def testVmapOfPmap(self, shapes, vmap_in_axes, pmap_in_axes, vmap_out_axes, pmap_out_axes):
|
2020-06-15 09:10:40 -07:00
|
|
|
vmapped_size = 3
|
2021-09-23 06:33:25 -07:00
|
|
|
pmapped_size = jax.device_count()
|
2020-06-15 09:10:40 -07:00
|
|
|
|
|
|
|
rng = jtu.rand_default(self.rng())
|
|
|
|
|
|
|
|
def fun(*args):
|
|
|
|
return sum(args)
|
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes, output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
final_shapes = map(partial(add_bdim, vmapped_size), vmap_in_axes,
|
|
|
|
map(partial(add_bdim, pmapped_size), pmap_in_axes, shapes))
|
|
|
|
|
|
|
|
def args_slice(vi, pi):
|
|
|
|
return args_slicer(args_slicer(args, vmap_in_axes)(vi), pmap_in_axes)(pi)
|
2020-06-15 09:10:40 -07:00
|
|
|
|
|
|
|
args = [rng(shape, jnp.float32) for shape in final_shapes]
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse, a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
ans = vmap(pmap(fun, in_axes=pmap_in_axes, out_axes=pmap_out_axes),
|
|
|
|
in_axes=vmap_in_axes,
|
|
|
|
out_axes=vmap_out_axes)(*args)
|
|
|
|
expected = np.stack(
|
|
|
|
[np.stack([fun(*args_slice(vi, pi)) for pi in range(pmapped_size)], axis=pmap_out_axes)
|
|
|
|
for vi in range(vmapped_size)],
|
|
|
|
axis=vmap_out_axes)
|
2020-06-15 09:10:40 -07:00
|
|
|
self.assertAllClose(ans, expected)
|
|
|
|
|
2022-01-28 08:16:30 -08:00
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
class VmapPmapCollectivesTest(jtu.JaxTestCase):
|
|
|
|
|
2020-09-21 16:25:50 +00:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_collective={}".format(collective.__name__).replace(" ", ""),
|
|
|
|
"collective": collective}
|
|
|
|
for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin])
|
|
|
|
def testCollectivesWithVmap(self, collective):
|
2020-08-14 18:22:04 +02:00
|
|
|
def f(map1, map2):
|
|
|
|
@partial(map1, axis_name='i')
|
|
|
|
@partial(map2, axis_name='j')
|
|
|
|
def f(x, y):
|
2020-09-21 16:25:50 +00:00
|
|
|
return x + collective(x.dot(y), ('i', 'j'))
|
2020-08-14 18:22:04 +02:00
|
|
|
return f
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() < 4:
|
2020-08-14 18:22:04 +02:00
|
|
|
raise SkipTest("test requires at least four devices")
|
|
|
|
x = jnp.ones((2, 2, 64, 64))
|
|
|
|
y = f(jax.pmap, jax.pmap)(x, x)
|
|
|
|
self.assertAllClose(f(jax.vmap, jax.vmap)(x, x), y)
|
|
|
|
self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y)
|
|
|
|
self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y)
|
|
|
|
|
2021-11-18 13:40:25 -08:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": "_collective={}".format(collective.__name__).replace(" ", ""),
|
|
|
|
"collective": collective}
|
|
|
|
for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin])
|
|
|
|
def testCollectivesWithVmap2(self, collective):
|
|
|
|
def f(map1, map2):
|
|
|
|
@partial(map1, axis_name='i')
|
|
|
|
@partial(map2, axis_name='j')
|
|
|
|
def f(x, y):
|
|
|
|
return x + collective(x.dot(y), ('i', 'j'))
|
|
|
|
return f
|
|
|
|
|
|
|
|
if jax.device_count() < 8:
|
|
|
|
raise SkipTest("test requires at least eight devices")
|
|
|
|
x = jnp.arange(4*2*64*64).reshape(4, 2, 64, 64)
|
|
|
|
y = f(jax.pmap, jax.pmap)(x, x)
|
|
|
|
self.assertAllClose(f(jax.vmap, jax.vmap)(x, x), y)
|
|
|
|
self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y)
|
|
|
|
self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y)
|
|
|
|
|
2020-09-21 14:14:52 +00:00
|
|
|
def testPPermuteWithVmap(self):
|
|
|
|
perm = [(0, 1), (1, 0)]
|
|
|
|
|
|
|
|
def f(map2):
|
|
|
|
@partial(jax.pmap, axis_name='i')
|
|
|
|
@partial(map2)
|
|
|
|
def f(x, y):
|
|
|
|
return x + jax.lax.ppermute(x.dot(y), 'i', perm)
|
|
|
|
return f
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() < 4:
|
2020-09-21 14:14:52 +00:00
|
|
|
raise SkipTest("test requires at least four devices")
|
|
|
|
x = jnp.ones((2, 2, 64, 64))
|
|
|
|
self.assertAllClose(f(jax.pmap)(x, x), f(jax.vmap)(x, x))
|
|
|
|
|
2021-11-18 13:40:25 -08:00
|
|
|
def testPPermuteAgreesWithVmap(self):
|
|
|
|
if jax.device_count() < 3:
|
|
|
|
raise SkipTest("test requires at least three devices")
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return lax.ppermute(x, 'i', [[1, 0], [2, 1], [0, 2]])
|
|
|
|
|
|
|
|
xs = jnp.arange(3) * 10
|
|
|
|
ys = jax.pmap(f, axis_name='i')(xs)
|
|
|
|
zs = jax.vmap(f, axis_name='i')(xs)
|
|
|
|
self.assertAllClose(ys, zs, check_dtypes=True)
|
|
|
|
|
2020-08-28 15:21:50 +00:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
|
|
|
|
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
|
|
|
|
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
|
|
|
|
@ignore_slow_all_to_all_warning()
|
|
|
|
def testAllToAllInVmap(self, split_axis, concat_axis, vmap_axis):
|
|
|
|
def f(x):
|
|
|
|
return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
|
|
|
|
|
|
|
|
def adj(axis, hidden_axes):
|
|
|
|
for hax in sorted(hidden_axes):
|
|
|
|
if hax <= axis:
|
|
|
|
axis += 1
|
|
|
|
return axis
|
|
|
|
|
|
|
|
def reference(x, split_axis, concat_axis, vmap_axis):
|
|
|
|
pmap_axis = 0
|
|
|
|
vmap_axis = adj(vmap_axis, [pmap_axis])
|
|
|
|
ref = x
|
|
|
|
|
|
|
|
# Step 1.
|
|
|
|
# Adjust the split axis to the real tensor layout and move it to
|
|
|
|
# position 1. Since pmap_axis is always 0 we don't have to adjust it,
|
|
|
|
# but we do have to adjust vmap_axis.
|
|
|
|
split_axis = adj(split_axis, [pmap_axis, vmap_axis])
|
|
|
|
ref = jnp.moveaxis(ref, split_axis, pmap_axis + 1)
|
|
|
|
vmap_axis = vmap_axis + (0 if split_axis < vmap_axis else 1)
|
|
|
|
split_axis = pmap_axis + 1 # split_axes == 1
|
|
|
|
|
|
|
|
# Step 2.
|
|
|
|
# Now, we move pmap_axis to the position indicated by concat_axis.
|
|
|
|
concat_axis = adj(concat_axis, [pmap_axis, split_axis, vmap_axis]) - 1
|
|
|
|
ref = jnp.moveaxis(ref, pmap_axis, concat_axis)
|
|
|
|
pmap_axis = 0
|
|
|
|
vmap_axis = vmap_axis - (1 if concat_axis >= vmap_axis else 0)
|
|
|
|
del split_axis, concat_axis
|
|
|
|
|
|
|
|
# Step 3. vmap_axis always ends in position 1, since out_axes=0.
|
|
|
|
ref = jnp.moveaxis(ref, vmap_axis, 1)
|
|
|
|
return ref
|
|
|
|
|
|
|
|
def verify_ref():
|
|
|
|
# Both the reference and the real implementation of all_to_all batching involve
|
|
|
|
# some pretty complicated axis arithmetic, so it would be good to verify that it's
|
|
|
|
# not the case that the test passes because they're both incorrect. Fortunately, it
|
|
|
|
# is quite easy to write out the shape function for this code, and we know
|
|
|
|
# that it should be equivalent to a bunch of transposes, so the code below verifies
|
|
|
|
# that the reference puts the right dimensions in the right places. Note that we
|
|
|
|
# can't do the same comparison on f, since all_to_all wouldn't allow us to swap axes of
|
|
|
|
# different sizes.
|
|
|
|
start_shape = [2, 3, 4, 5, 6]
|
|
|
|
instance_shape = start_shape.copy()
|
|
|
|
pmap_dim_id = instance_shape.pop(0)
|
|
|
|
vmap_dim_id = instance_shape.pop(vmap_axis)
|
|
|
|
split_axis_id = instance_shape.pop(split_axis)
|
|
|
|
instance_shape.insert(concat_axis, pmap_dim_id)
|
|
|
|
expected_shape = (split_axis_id, vmap_dim_id, *instance_shape)
|
|
|
|
|
|
|
|
x = np.empty(start_shape)
|
|
|
|
self.assertEqual(reference(x, split_axis, concat_axis, vmap_axis).shape,
|
|
|
|
expected_shape)
|
|
|
|
|
|
|
|
verify_ref()
|
|
|
|
|
|
|
|
shape = (jax.device_count(),) * 5
|
|
|
|
x = jnp.arange(np.prod(shape)).reshape(shape)
|
|
|
|
self.assertAllClose(pmap(vmap(f, in_axes=vmap_axis), axis_name='i')(x),
|
|
|
|
reference(x, split_axis, concat_axis, vmap_axis))
|
|
|
|
|
2020-09-22 11:19:06 +00:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
|
|
|
|
"split_axis": split_axis, "concat_axis": concat_axis}
|
|
|
|
for split_axis, concat_axis in it.product(range(3), range(3)))
|
|
|
|
@ignore_slow_all_to_all_warning()
|
|
|
|
def testAllToAllVsVmap(self, split_axis, concat_axis):
|
|
|
|
def f(x):
|
|
|
|
return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
|
|
|
|
|
|
|
|
shape = (jax.device_count(),) * 4
|
|
|
|
x = jnp.arange(np.prod(shape)).reshape(shape)
|
|
|
|
self.assertAllClose(pmap(f, axis_name='i')(x),
|
|
|
|
vmap(f, axis_name='i')(x))
|
|
|
|
|
2020-09-22 13:05:08 +00:00
|
|
|
@parameterized.named_parameters(
|
|
|
|
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_axes={''.join(axes)}",
|
|
|
|
"axes": axes, "split_axis": split_axis, "concat_axis": concat_axis}
|
|
|
|
for axes, split_axis, concat_axis
|
|
|
|
in it.product([('i', 'j'), ('j', 'i')], range(3), range(3)))
|
|
|
|
@ignore_slow_all_to_all_warning()
|
2021-10-04 17:54:18 -07:00
|
|
|
@unittest.skip("multi-axis all_to_all broken after #4835") # TODO(mattjj,apaszke)
|
2020-09-22 13:05:08 +00:00
|
|
|
def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() < 4:
|
2020-09-22 13:05:08 +00:00
|
|
|
raise SkipTest("test requires at least four devices")
|
|
|
|
|
|
|
|
def f(x):
|
|
|
|
return lax.all_to_all(x, axes, split_axis=split_axis, concat_axis=concat_axis)
|
|
|
|
|
|
|
|
shape = (2, 2, 4, 4, 4)
|
|
|
|
x = jnp.arange(np.prod(shape)).reshape(shape)
|
|
|
|
self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x),
|
|
|
|
vmap(vmap(f, axis_name='j'), axis_name='i')(x))
|
|
|
|
|
2021-01-13 10:33:03 +00:00
|
|
|
def testAllGatherWithVmap(self):
|
|
|
|
def f(map2):
|
|
|
|
@partial(jax.pmap, axis_name='i')
|
|
|
|
@partial(map2)
|
|
|
|
def f(x):
|
|
|
|
return jax.lax.all_gather(x, 'i')
|
|
|
|
return f
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() < 4:
|
2021-01-13 10:33:03 +00:00
|
|
|
raise SkipTest("test requires at least four devices")
|
|
|
|
x = jnp.ones((2, 2, 64, 64))
|
|
|
|
self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
|
|
|
|
|
2020-06-15 07:32:42 -07:00
|
|
|
|
2019-08-26 11:22:58 -07:00
|
|
|
class PmapWithDevicesTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def testAllDevices(self):
|
|
|
|
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
|
2021-09-23 06:33:25 -07:00
|
|
|
devices=jax.devices())
|
|
|
|
shape = (jax.device_count(), 4)
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
|
|
|
expected = x - np.sum(x, 0)
|
2019-08-26 11:22:58 -07:00
|
|
|
ans = f(x)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
|
|
|
def testOneDevice(self):
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() == 1:
|
2019-08-26 11:22:58 -07:00
|
|
|
raise SkipTest("this test requires multiple devices")
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
d0 = jax.devices()[0]
|
|
|
|
d1 = jax.devices()[1]
|
2020-05-05 14:59:16 -04:00
|
|
|
f = lambda x: jnp.dot(x, x.T)
|
2019-08-26 11:22:58 -07:00
|
|
|
f0 = pmap(f, devices=[d0])
|
|
|
|
f1 = pmap(f, devices=[d1])
|
2021-12-10 10:32:09 -08:00
|
|
|
x = self.rng().rand(1, 1000, 1000)
|
2019-08-26 11:22:58 -07:00
|
|
|
r0 = f0(x)
|
|
|
|
r1 = f1(x)
|
2020-05-05 14:59:16 -04:00
|
|
|
expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0)
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(r0, expected, atol=1e-6, rtol=1e-3)
|
|
|
|
self.assertAllClose(r1, expected, atol=1e-6, rtol=1e-3)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
|
|
|
def testNoDevicesError(self):
|
|
|
|
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
|
2021-09-23 06:33:25 -07:00
|
|
|
shape = (jax.device_count(), 4)
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
2019-08-26 11:22:58 -07:00
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError, "'devices' argument to pmap must be non-empty, or None."):
|
|
|
|
f(x)
|
|
|
|
|
|
|
|
def testBadAxisSizeError(self):
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() == 1:
|
2019-08-26 11:22:58 -07:00
|
|
|
raise SkipTest("this test requires multiple devices")
|
|
|
|
|
|
|
|
f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
|
2021-09-23 06:33:25 -07:00
|
|
|
devices=jax.devices())
|
2019-08-26 11:22:58 -07:00
|
|
|
with self.assertRaisesRegex(
|
2019-09-27 11:50:21 -07:00
|
|
|
ValueError, r"Leading axis size of input to pmapped function must "
|
|
|
|
r"equal the number of local devices passed to pmap. Got axis_size=1, "
|
|
|
|
r"num_local_devices=\d."):
|
2020-05-05 14:59:16 -04:00
|
|
|
f(jnp.ones(1))
|
2019-08-26 11:22:58 -07:00
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
2019-09-27 11:50:21 -07:00
|
|
|
ValueError, r"Leading axis size of input to pmapped function must "
|
|
|
|
r"equal the number of local devices passed to pmap. Got axis_size=\d, "
|
|
|
|
r"num_local_devices=\d."):
|
2021-09-23 06:33:25 -07:00
|
|
|
f(jnp.ones(jax.device_count() + 1))
|
2019-08-26 11:22:58 -07:00
|
|
|
|
2021-03-24 12:02:04 -07:00
|
|
|
def testBadAxisSizeErrorNested(self):
|
|
|
|
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
|
|
|
|
axis_name='j'),
|
|
|
|
axis_name='i',
|
|
|
|
devices=[jax.local_devices()[0]])
|
|
|
|
with self.assertRaisesRegex(
|
|
|
|
ValueError,
|
|
|
|
r"pmapped function requires 4 local devices to run due to nested "
|
|
|
|
r"pmapped or other parallel functions, but only 1 are available."):
|
|
|
|
f(jnp.ones((1, 4)))
|
|
|
|
|
2020-06-19 15:51:12 -07:00
|
|
|
def testNestedPmaps(self):
|
2021-09-23 06:33:25 -07:00
|
|
|
if jax.device_count() % 2 != 0:
|
2020-06-19 15:51:12 -07:00
|
|
|
raise SkipTest
|
|
|
|
|
|
|
|
# Devices specified in outer pmap are OK
|
2021-09-23 06:33:25 -07:00
|
|
|
@partial(pmap, axis_name='i', devices=jax.devices())
|
2019-08-26 11:22:58 -07:00
|
|
|
def foo(x):
|
|
|
|
@partial(pmap, axis_name='j')
|
|
|
|
def bar(y):
|
|
|
|
return lax.psum(y, 'j')
|
|
|
|
return bar(x)
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
x = jnp.ones((jax.device_count() // 2, 2))
|
2020-06-19 15:51:12 -07:00
|
|
|
ans = foo(x)
|
|
|
|
expected = x * 2
|
|
|
|
self.assertAllClose(ans, expected)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
2021-12-03 13:34:26 -08:00
|
|
|
def testNestedPmapsBools(self):
|
|
|
|
if jax.device_count() % 2 != 0:
|
|
|
|
raise SkipTest
|
|
|
|
|
|
|
|
# Devices specified in outer pmap are OK
|
|
|
|
@partial(pmap, axis_name='i', devices=jax.devices())
|
|
|
|
def foo(x):
|
|
|
|
@partial(pmap, axis_name='j')
|
|
|
|
def bar(y):
|
|
|
|
return jnp.logical_not(y)
|
|
|
|
return bar(x)
|
|
|
|
|
|
|
|
x = jnp.ones((jax.device_count() // 2, 2), jnp.bool_)
|
|
|
|
ans = foo(x)
|
|
|
|
expected = jnp.zeros((jax.device_count() // 2, 2), jnp.bool_)
|
|
|
|
self.assertAllClose(ans, expected)
|
|
|
|
|
2020-06-19 15:51:12 -07:00
|
|
|
def testNestedPmapsError(self):
|
|
|
|
# Devices specified in inner pmap not OK
|
2019-08-26 11:22:58 -07:00
|
|
|
@partial(pmap, axis_name='i')
|
|
|
|
def foo(x):
|
2021-09-23 06:33:25 -07:00
|
|
|
@partial(pmap, axis_name='j', devices=jax.devices())
|
2019-08-26 11:22:58 -07:00
|
|
|
def bar(y):
|
|
|
|
return lax.psum(y, 'j')
|
|
|
|
return bar(x)
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(
|
2019-08-29 20:25:02 -07:00
|
|
|
ValueError,
|
2020-06-19 15:51:12 -07:00
|
|
|
"Nested pmap with explicit devices argument."):
|
2021-09-23 06:33:25 -07:00
|
|
|
foo(jnp.ones((jax.device_count(), 1)))
|
2019-08-26 11:22:58 -07:00
|
|
|
|
|
|
|
def testJitInPmap(self):
|
2021-09-23 06:33:25 -07:00
|
|
|
@partial(pmap, axis_name='i', devices=jax.devices())
|
2019-08-26 11:22:58 -07:00
|
|
|
def foo(x):
|
|
|
|
@jit
|
|
|
|
def bar(y):
|
|
|
|
return y + 1
|
|
|
|
return lax.psum(bar(x), 'i')
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
ndevices = jax.device_count()
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = foo(jnp.ones((ndevices, 1)))
|
|
|
|
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
2020-08-10 19:09:34 +02:00
|
|
|
@ignore_jit_of_pmap_warning()
|
2019-08-26 11:22:58 -07:00
|
|
|
def testPmapInJit(self):
|
|
|
|
@jit
|
|
|
|
def foo(x):
|
2021-09-23 06:33:25 -07:00
|
|
|
@partial(pmap, axis_name='i', devices=jax.devices())
|
2019-08-26 11:22:58 -07:00
|
|
|
def bar(y):
|
|
|
|
return lax.psum(y, 'i')
|
|
|
|
return bar(x)
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
ndevices = jax.device_count()
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = foo(jnp.ones((ndevices, 1)))
|
|
|
|
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices
|
2020-06-01 17:19:23 -04:00
|
|
|
self.assertAllClose(ans, expected)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
|
|
|
def testGradBasic(self):
|
2021-09-23 06:33:25 -07:00
|
|
|
@partial(pmap, axis_name='i', devices=jax.devices())
|
2019-08-26 11:22:58 -07:00
|
|
|
def f(x):
|
2020-05-05 14:59:16 -04:00
|
|
|
return jnp.sin(x)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
shape = (jax.device_count(), 4)
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
2019-08-26 11:22:58 -07:00
|
|
|
|
2020-05-05 14:59:16 -04:00
|
|
|
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
|
|
|
|
expected = grad(lambda x: jnp.sum(f(x)))(x)
|
2019-08-26 11:22:58 -07:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
2020-02-14 15:45:26 +00:00
|
|
|
def testPmapStaticArgnums(self):
|
|
|
|
@partial(pmap, axis_name='i', static_broadcasted_argnums=1)
|
|
|
|
def f(x, y):
|
2020-10-16 13:11:56 -07:00
|
|
|
return jnp.sin(x + y())
|
2021-09-23 06:33:25 -07:00
|
|
|
shape = (jax.device_count(), 4)
|
2020-05-05 14:59:16 -04:00
|
|
|
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
2020-10-16 13:11:56 -07:00
|
|
|
y = lambda: 3.
|
2020-02-14 15:45:26 +00:00
|
|
|
|
|
|
|
ans = f(x, y)
|
2020-10-16 13:11:56 -07:00
|
|
|
expected = np.sin(x + 3.)
|
2020-02-14 15:45:26 +00:00
|
|
|
self.assertAllClose(ans, expected, check_dtypes=False)
|
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
def testPmapInAxesBasic(self):
|
2020-11-05 11:54:05 +00:00
|
|
|
@partial(pmap, in_axes=(1, 2))
|
|
|
|
def f(x, y):
|
|
|
|
return jnp.sin(x + y)
|
2021-09-23 06:33:25 -07:00
|
|
|
xshape = (2, jax.device_count(), 4)
|
2020-11-05 11:54:05 +00:00
|
|
|
x = np.arange(prod(xshape)).reshape(xshape)
|
2021-09-23 06:33:25 -07:00
|
|
|
yshape = (2, 4, jax.device_count())
|
2020-11-05 11:54:05 +00:00
|
|
|
y = np.arange(prod(yshape)).reshape(yshape)
|
|
|
|
|
|
|
|
self.assertAllClose(f(x, y),
|
|
|
|
jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1))))
|
|
|
|
|
|
|
|
def testPmapInAxesGrad(self):
|
|
|
|
def f(x, y, z):
|
|
|
|
return jnp.sin(x + y + z)
|
|
|
|
fp = pmap(f, in_axes=(1, 2, None))
|
|
|
|
fv = vmap(f, in_axes=(1, 2, None))
|
2021-09-23 06:33:25 -07:00
|
|
|
xshape = (5, jax.device_count(), 7)
|
2020-11-05 11:54:05 +00:00
|
|
|
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
|
2021-09-23 06:33:25 -07:00
|
|
|
yshape = (5, 7, jax.device_count())
|
2020-11-05 11:54:05 +00:00
|
|
|
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
|
|
|
|
zshape = (5, 7)
|
|
|
|
z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape)
|
|
|
|
|
|
|
|
dx, dy, dz = jax.grad(lambda args: fp(*args).sum())((x, y, z))
|
|
|
|
assert dx.shape == xshape
|
|
|
|
assert dy.shape == yshape
|
|
|
|
assert dz.shape == zshape
|
|
|
|
|
|
|
|
self.assertAllClose(jax.grad(lambda args: fp(*args).sum())((x, y, z)),
|
|
|
|
jax.grad(lambda args: fv(*args).sum())((x, y, z)))
|
2019-08-26 11:22:58 -07:00
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
def testPmapOutAxesBasic(self):
|
|
|
|
@partial(pmap, in_axes=(1, None), out_axes=(2, None))
|
|
|
|
def f(x, y):
|
|
|
|
return jnp.sin(x + y), y * 2
|
2021-09-23 06:33:25 -07:00
|
|
|
xshape = (2, jax.device_count(), 4)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
x = np.arange(prod(xshape)).reshape(xshape)
|
|
|
|
yshape = (2, 4)
|
|
|
|
y = np.arange(prod(yshape)).reshape(yshape)
|
|
|
|
|
|
|
|
self.assertAllClose(f(x, y),
|
|
|
|
(jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2))
|
|
|
|
|
2021-04-21 11:49:21 +01:00
|
|
|
def testPmapDictOutAxes(self):
|
|
|
|
# see issue #6410
|
|
|
|
@partial(pmap, out_axes={'a': 0})
|
|
|
|
def f(x):
|
|
|
|
return {'a': x}
|
2021-09-23 06:33:25 -07:00
|
|
|
device_count = jax.device_count()
|
2021-04-21 11:49:21 +01:00
|
|
|
x = jnp.arange(device_count)
|
2022-04-01 14:51:54 -07:00
|
|
|
tree_util.tree_map(self.assertAllClose, f(x), {'a': x})
|
2021-04-21 11:49:21 +01:00
|
|
|
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
@parameterized.named_parameters(jtu.cases_from_list(
|
|
|
|
{"testcase_name": f"_{in_axes}_{out_axes}",
|
|
|
|
"in_axes": in_axes, "out_axes": out_axes}
|
|
|
|
for in_axes in all_bdims((3, 4), (3, 1), (1, 4), pmap=True)
|
|
|
|
for out_axes in out_bdims((3, 4), True)
|
|
|
|
))
|
|
|
|
def testPmapAllAxesGrad(self, in_axes, out_axes):
|
|
|
|
def f(x, y, z):
|
|
|
|
return jnp.sin(x + y) * z
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
pmapped_size = jax.device_count()
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
mapped_shapes = [(3, 4), (3, 1), (1, 4)]
|
|
|
|
arg_shapes = map(partial(add_bdim, pmapped_size), in_axes, mapped_shapes)
|
|
|
|
rng = jtu.rand_default(self.rng())
|
|
|
|
args = [rng(shape, jnp.float64) for shape in arg_shapes]
|
|
|
|
jtu.check_grads(pmap(f, in_axes=in_axes, out_axes=out_axes), args,
|
|
|
|
order=2, atol=2e-2, rtol=2e-2, eps=1e-3)
|
|
|
|
|
|
|
|
def testPmapPostProcess(self):
|
|
|
|
def mk_case(map_fun):
|
|
|
|
def f(x, y):
|
|
|
|
# NOTE: Map doesn't have any arguments we differentiate wrt
|
|
|
|
@partial(map_fun, in_axes=1, out_axes=2)
|
|
|
|
def h(y):
|
|
|
|
return jnp.sin(x + y)
|
|
|
|
return h(y).sum()
|
|
|
|
return f
|
|
|
|
|
|
|
|
xshape = (5, 7)
|
|
|
|
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
|
2021-09-23 06:33:25 -07:00
|
|
|
yshape = (5, jax.device_count(), 7)
|
Add support for non-zero (but still not-None) out_axes in pmap
Previously `pmap` didn't have the `out_axes` parameter (unlike `vmap`),
but its semantics would match the specification of `out_axes=0` (i.e.
all outputs should be stacked along the first axis). This patch makes it
possible to specify non-zero values for out_axes, but more importantly
it lays down the groundwork for `xmap` which will have to use some
extremely similar (if not the same) code paths.
One thing to note is that when I started this implementation I was also
planning to add support for `out_axes=None`, which would allow us to
stop using the `unbroadcast` hack, and most of the code is written with
that in mind. Unfortunately it turned out that the correct
implementation of the transpose rule for maps that do allow unmapped
outputs would require me to pretty much simulate what avals-with-names
is supposed to achieve. Technically replicated outputs should work
today, for as long as the user does not do reverse-mode AD of `pmap`.
But I decided that it's better to just disable them altogether until we
can get the full and correct behavior.
* Implementation details *
This patch is significantly more involved than the one that implemented
general `in_axes` support. That previous one at least had the foundation
of `mapped_invars` which already behaved pretty similarly to general
`in_axes`. From a quick glance one might think that `out_axes` should
behave similarly to `in_axes`, but it turns out that this is not the
case, at least not if we're interested in keeping those primitives
final-style.
** Thunking **
The biggest difficulty with handling `out_axes` in final style
primitives is that we want to treat them as a prefix of the output
pytree, but we don't know the structure of the output pytree until the
user function is evaluated! And the user function is not evaluated until
we've applied all transforms and reached the impl rule! The solution to
this problem is "straightforward": instead of putting `out_axes` as a
primitive parameter, we bundle an `out_axes_thunk` which can only be
called successfully after the wrapped function has been executed. The
thunk returns a list of flat `out_axes`, expanded to the output pytree.
However, the thunking presents us with two problems:
*** Transformations ***
Each transformation that modifies the number of outputs needs to ensure
that the thunk is updated to reflect the new values. To make things
worse a lot of the transforms can learn the number of added outputs
_only after the wrapped function is evaluated_, which leads to the
following "time travel" pattern that can be found in most `Trace`s:
```py
@lu.transformation_with_aux
def compute_output_statistic(*args, **kwargs):
outputs = yield args, kwargs
yield outputs, compute_statistic(outputs)
wrapped_fun, output_statistic = compute_output_statistic(wrapped_fun)
def new_out_axes_thunk():
old_out_axes = params['out_axes_thunk']()
return compute_new_out_axes(old_out_axes(), output_statistic())
primitive.bind(wrapped_fun, dict(params, out_axes_thunk=new_out_axes_thunk))
```
The reason why we have to structure the code this way is that we can
only specify a new `out_axes_thunk` before we bind the primitive, but we
need the outputs of bind to know how to update the `out_axes_thunk`. To
make things worse, the implementation of `bind` is allowed to make a
call to `out_axes_thunk` _immediately after `wrapped_fun` is evaluated_.
This means that we cannot compute the output statistic in the
implementation of the transformation, but we have to use an extra
`lu.transformation_with_aux` for that (this populates the statistic
store immediately after `wrapped_fun` is evaluated).
The `compute_statistic` function depends on the transform in question.
E.g. in the JVP trace it counts the number of non-zero tangent results.
The situation is of course further complicated when we take
`post_process_map` into account. The new `process_env_traces` now always
sets up this funny time travel trampoline just in case it ends up being
necessary, and `post_process_map` is now expected to return `(outputs,
(todo, out_axes_transform))` instead of just `(outputs, todo)`.
*** Compilation cache ***
Because the `out_axes_thunk`s are now arguments to a _global_
compilation cache (in the form of `lu.cache` decorator on
`parallel_callable`), we have to ensure that they implement `hash` and
`==`. This is what forces us to add some slightly weird helpers such as
`_hashable_function` and `_ignore_elem_list`. The code that uses those
makes an assumption that the output pytree depends deterministically on
the identity of the wrapped function, which I think is in line with
general JAX assumptions. Otherwise the cache would depend on the
identity of the thunk, which changes with every function invocation.
Relaxing the global constraint on the cache (e.g. allowing each
`pmap(f)` instance to have a separate cache) would make this easier too.
* Why final style? *
Now, making the primitives initial-style would remove the necessity for
thunking, because we could have obtained the output pytree right when
the function is wrapped. I assumed there is a good argument for making
`pmap` pretend that it's a final-style primitive, but I'm not sure why
that is? I hope it's something better than just avoiding a single jaxpr
tracing.
2020-11-09 17:23:16 +00:00
|
|
|
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
|
|
|
|
self.assertAllClose(jax.grad(mk_case(pmap))(x, y),
|
|
|
|
jax.grad(mk_case(vmap))(x, y))
|
|
|
|
|
|
|
|
|
2020-04-23 16:01:05 -07:00
|
|
|
class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
|
|
|
|
|
|
|
def testThreadsafeIndexing(self):
|
|
|
|
# NOTE(skye): I picked these values to be big enough to cause interesting
|
|
|
|
# execution overlap, but small enough to not use too much memory. YMMV.
|
|
|
|
shape = (8, 8000, 1000)
|
|
|
|
|
|
|
|
if jax.device_count() < shape[0]:
|
|
|
|
raise SkipTest(f"requires {shape[0]} devices")
|
|
|
|
|
2020-08-18 10:17:38 -07:00
|
|
|
x = jnp.arange(prod(shape)).reshape(shape)
|
2020-04-23 16:01:05 -07:00
|
|
|
sharded_x = pmap(lambda x: x)(x)
|
|
|
|
|
|
|
|
num_threads = 10
|
|
|
|
futures = []
|
|
|
|
expected = []
|
|
|
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
|
|
|
for i in range(num_threads):
|
|
|
|
idx = i % shape[0]
|
|
|
|
# Mix together different kinds of indices
|
|
|
|
if i % 2 == 0:
|
|
|
|
idx = slice(idx, idx + 1)
|
2020-04-28 16:02:30 -07:00
|
|
|
# Use the "kwarg trick" to work around late-binding closures. See
|
|
|
|
# https://docs.python-guide.org/writing/gotchas/#late-binding-closures.
|
2020-04-23 16:01:05 -07:00
|
|
|
futures.append(executor.submit(
|
2020-04-28 16:02:30 -07:00
|
|
|
lambda idx=idx: [sharded_x[idx] for _ in range(10)][0]))
|
2020-04-23 16:01:05 -07:00
|
|
|
expected.append(x[idx])
|
|
|
|
actual = [f.result() for f in futures]
|
|
|
|
self.assertAllClose(actual, expected, check_dtypes=False)
|
|
|
|
|
2020-11-06 12:55:17 +00:00
|
|
|
def testNoCopyIndexing1D(self):
|
|
|
|
shape = (8, 4)
|
|
|
|
|
|
|
|
if jax.device_count() < shape[0]:
|
|
|
|
raise SkipTest(f"requires {shape[0]} devices")
|
|
|
|
|
|
|
|
x = jnp.arange(prod(shape)).reshape(shape)
|
|
|
|
sharded_x = pmap(lambda x: x)(x)
|
|
|
|
self.assertIsNone(sharded_x._npy_value)
|
|
|
|
for i in range(8):
|
2021-11-22 08:22:10 -08:00
|
|
|
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
|
2020-11-06 12:55:17 +00:00
|
|
|
self.assertIsNone(sharded_x._npy_value)
|
|
|
|
|
2020-12-04 12:53:36 -08:00
|
|
|
def test_device_put_sharded_array(self):
|
|
|
|
devices = jax.local_devices()
|
|
|
|
n_devices = len(devices)
|
|
|
|
x = [np.arange(i, i + 4) for i in range(n_devices)]
|
|
|
|
y = jax.device_put_sharded(x, devices)
|
|
|
|
self.assertIsInstance(y, pxla.ShardedDeviceArray)
|
|
|
|
self.assertEqual(len(y.device_buffers), len(devices))
|
|
|
|
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
|
2020-12-04 17:07:23 -08:00
|
|
|
self.assertArraysEqual(y, jnp.stack(x))
|
2020-12-04 12:53:36 -08:00
|
|
|
|
|
|
|
def test_device_put_sharded_pytree(self):
|
|
|
|
devices = jax.local_devices()
|
|
|
|
n_devices = len(devices)
|
|
|
|
x = [(i, np.arange(i, i + 4)) for i in range(n_devices)]
|
|
|
|
y1, y2 = jax.device_put_sharded(x, devices)
|
|
|
|
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
|
2020-12-04 17:07:23 -08:00
|
|
|
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
|
2020-12-04 12:53:36 -08:00
|
|
|
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
|
|
|
|
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
|
2020-12-04 17:07:23 -08:00
|
|
|
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
|
2020-12-04 12:53:36 -08:00
|
|
|
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
|
|
|
|
|
|
|
|
def test_device_put_replicated_array(self):
|
|
|
|
devices = jax.local_devices()
|
|
|
|
x = np.arange(1, 5)
|
|
|
|
y = jax.device_put_replicated(x, devices)
|
|
|
|
self.assertIsInstance(y, pxla.ShardedDeviceArray)
|
|
|
|
self.assertEqual(len(y.device_buffers), len(devices))
|
|
|
|
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
|
2020-12-04 17:07:23 -08:00
|
|
|
self.assertArraysEqual(y, np.stack([x for _ in devices]))
|
2020-12-04 12:53:36 -08:00
|
|
|
|
|
|
|
def test_device_put_replicated_pytree(self):
|
|
|
|
devices = jax.local_devices()
|
|
|
|
xs = {'a': np.arange(1, 5), 'b': np.arange(3)}
|
|
|
|
ys = jax.device_put_replicated(xs, devices)
|
|
|
|
self.assertIsInstance(ys, dict)
|
|
|
|
y1, y2 = ys['a'], ys['b']
|
|
|
|
|
|
|
|
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
|
|
|
|
self.assertEqual(len(y1.device_buffers), len(devices))
|
|
|
|
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
|
2020-12-04 17:07:23 -08:00
|
|
|
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
|
2020-12-04 12:53:36 -08:00
|
|
|
|
|
|
|
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
|
|
|
|
self.assertEqual(len(y2.device_buffers), len(devices))
|
|
|
|
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
|
2020-12-04 17:07:23 -08:00
|
|
|
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
|
2020-12-04 12:53:36 -08:00
|
|
|
|
2020-12-04 21:25:51 -08:00
|
|
|
def test_repr(self):
|
|
|
|
x = jax.device_put_replicated(1, jax.devices())
|
|
|
|
self.assertStartsWith(repr(x), 'ShardedDeviceArray')
|
|
|
|
|
2021-08-10 11:10:34 -07:00
|
|
|
def test_delete_is_idempotent(self):
|
|
|
|
x = jax.device_put_replicated(1, jax.devices())
|
|
|
|
x.delete()
|
|
|
|
x.delete()
|
|
|
|
|
|
|
|
with self.assertRaisesRegex(ValueError,
|
|
|
|
'ShardedDeviceArray has been deleted.'):
|
|
|
|
_ = x[0]
|
|
|
|
|
2020-04-23 16:01:05 -07:00
|
|
|
|
2020-04-15 12:43:55 -07:00
|
|
|
class SpecToIndicesTest(jtu.JaxTestCase):
  """Tests for ``pxla.spec_to_indices``.

  Each case builds a ``pxla.ShardingSpec`` for a fixed shape and compares the
  returned tuple of index tuples against an explicitly constructed expectation.
  """

  def _check_indices(self, shape, spec, expected):
    """Assert that ``pxla.spec_to_indices(shape, spec) == expected``."""
    self.assertEqual(pxla.spec_to_indices(shape, spec), expected)

  def testShardsPerAxis(self):
    spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
                             mesh_mapping=map(pxla.ShardedAxis, (0, 1)))
    rows = (slice(0, 2), slice(2, 4))
    cols = (slice(0, 4), slice(4, 8))
    # With mesh_mapping (0, 1), the column chunk varies fastest.
    expected = tuple((r, c) for r in rows for c in cols)
    self._check_indices((4, 8), spec, expected)

  def testShardedAxisPermutation(self):
    spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
                             mesh_mapping=map(pxla.ShardedAxis, (1, 0)))
    rows = (slice(0, 2), slice(2, 4))
    cols = (slice(0, 4), slice(4, 8))
    # With mesh_mapping (1, 0), the row chunk varies fastest.
    expected = tuple((r, c) for c in cols for r in rows)
    self._check_indices((4, 8), spec, expected)

  def testShardedAxisPermutationAndReplication(self):
    spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
                             mesh_mapping=(pxla.Replicated(2),
                                           pxla.ShardedAxis(1),
                                           pxla.ShardedAxis(0)))
    rows = (slice(0, 2), slice(2, 4))
    cols = (slice(0, 4), slice(4, 8))
    one_copy = tuple((r, c) for c in cols for r in rows)
    # The leading Replicated(2) repeats the whole permuted layout twice.
    self._check_indices((4, 8), spec, one_copy * 2)

  def testUnshardedAxis(self):
    spec = pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
                             mesh_mapping=(pxla.ShardedAxis(0),))
    expected = ((slice(0, 2), slice(None)),
                (slice(2, 4), slice(None)))
    self._check_indices((4, 8), spec, expected)

  def testNoSharding(self):
    spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                             mesh_mapping=())
    self._check_indices((4, 8), spec, ((slice(None), slice(None)),))

  def testUnmaterializedAxis(self):
    # Unstacked along axis 0: the expected indices pick one integer row each.
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()),
                             mesh_mapping=(pxla.ShardedAxis(0),))
    self._check_indices((4, 8), spec,
                        tuple((i, slice(None)) for i in range(4)))

    # Unstacked along axis 1: the expected indices pick one integer column each.
    spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)),
                             mesh_mapping=(pxla.ShardedAxis(0),))
    self._check_indices((2, 2), spec,
                        tuple((slice(None), j) for j in range(2)))

  def testReplicationAfterUnsharded(self):
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                             mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))
    # Trailing replication: each index appears 3 times in a row.
    expected = tuple((i, slice(None)) for i in range(2) for _ in range(3))
    self._check_indices((2, 8), spec, expected)

  def testReplicationPosition2(self):
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
                             mesh_mapping=(pxla.ShardedAxis(0), pxla.ShardedAxis(1),
                                           pxla.Replicated(3)))
    halves = (slice(0, 4), slice(4, 8))
    expected = tuple((i, h) for i in range(2) for h in halves for _ in range(3))
    self._check_indices((2, 8), spec, expected)

  def testReplicationPosition1(self):
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
                             mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3),
                                           pxla.ShardedAxis(1)))
    halves = (slice(0, 4), slice(4, 8))
    expected = tuple((i, h) for i in range(2) for _ in range(3) for h in halves)
    self._check_indices((2, 8), spec, expected)

  def testReplicationPosition0(self):
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                             mesh_mapping=(pxla.Replicated(3), pxla.ShardedAxis(0)))
    # Leading replication: the full (0, 1) sequence repeats 3 times.
    expected = tuple((i, slice(None)) for _ in range(3) for i in range(2))
    self._check_indices((2, 8), spec, expected)

  def testMultipleReplications(self):
    shape = (2, 7, 4)
    spec = pxla.ShardingSpec(
        sharding=(pxla.Unstacked(2), pxla.NoSharding(), pxla.Chunked([2])),
        mesh_mapping=(pxla.Replicated(3), pxla.Replicated(2),
                      pxla.ShardedAxis(0), pxla.Replicated(2),
                      pxla.ShardedAxis(1)))
    chunks = (slice(0, 2), slice(2, 4))
    # For each unstacked index, the pair of chunks is repeated twice by the
    # Replicated(2) sitting between the two sharded axes.
    inner = tuple((i, slice(None), c)
                  for i in range(2) for _ in range(2) for c in chunks)
    # The two leading replications (3 and 2) repeat the whole layout.
    self._check_indices(shape, spec, inner * 3 * 2)

  def testReplicatedScalar(self):
    spec = pxla.ShardingSpec(sharding=(),
                             mesh_mapping=(pxla.Replicated(3),))
    self._check_indices((), spec, ((),) * 3)
|
|
|
|
|
2020-04-15 12:43:55 -07:00
|
|
|
|
2020-05-06 10:19:28 -07:00
|
|
|
def _spec_str(spec):
|
2020-11-06 12:55:17 +00:00
|
|
|
return (f"({spec.sharding},"
|
|
|
|
f"{spec.mesh_mapping},)")
|
2020-05-06 10:19:28 -07:00
|
|
|
|
|
|
|
|
|
|
|
class ShardArgsTest(jtu.JaxTestCase):
  """Checks ``pxla.shard_args`` against plain numpy indexing of the same array."""

  # Argument constructors used by the parameterization below. They are
  # evaluated as plain functions in the class body while the decorator runs.
  def numpy_array(x):
    return x

  def device_array(x):
    return jax.device_put(x)

  # TODO(skye): add coverage for ShardedDeviceArrays

  @parameterized.named_parameters(
      {"testcase_name":
       f"_shape={shape}_spec={_spec_str(spec)}_arg={make_arg.__name__}"
       .replace(" ", ""),
       "shape": shape, "spec": spec, "make_arg": make_arg}
      for make_arg in [numpy_array, device_array]
      for shape, spec in [
          # pmap(in_axes=0)
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()),
                                     mesh_mapping=(pxla.ShardedAxis(0),))],
          # pmap(in_axes=1)
          [(2, 2), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)),
                                     mesh_mapping=(pxla.ShardedAxis(0),))],
          # unsharded
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                     mesh_mapping=())],
          # partitioned, 1 axis
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
                                     mesh_mapping=(pxla.ShardedAxis(0),))],
          # partitioned, 2 axes
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
                                     mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
          # partitioned, 2 axes, permuted
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
                                     mesh_mapping=map(pxla.ShardedAxis, (1, 0)))],
          # partitioned + sharding
          [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
                                     mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
          # replication + sharding
          [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                                     mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))],
          # replication, no sharding
          [(2, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                     mesh_mapping=(pxla.Replicated(3),))],
          # multiple replicated axes
          [(1, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(1), pxla.Chunked([2])),
                                     mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(0),
                                                   pxla.Replicated(2), pxla.ShardedAxis(1)))],
          # replicated scalar
          [(), pxla.ShardingSpec(sharding=(),
                                 mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
      ])
  def testShardArgs(self, shape, spec, make_arg):
    """Shard an arange of ``shape`` per ``spec`` and compare every shard.

    Args:
      shape: shape of the test array.
      spec: the ``pxla.ShardingSpec`` used to derive per-shard indices.
      make_arg: turns the numpy input into the argument given to shard_args.
    """
    indices = pxla.spec_to_indices(shape, spec)
    nshards = len(indices)
    available = jax.device_count()
    if available < nshards:
      # Give a reason so skipped cases are distinguishable in test reports.
      raise SkipTest(f"test requires {nshards} devices, found {available}")
    x = np.arange(prod(shape)).reshape(shape)
    arg = make_arg(x)
    bufs = pxla.shard_args(jax.devices()[:nshards],
                           [indices], [arg])
    # A single argument was passed, so a single list of buffers is expected,
    # holding one buffer per shard.
    self.assertEqual(len(bufs), 1)
    self.assertEqual(len(bufs[0]), nshards)
    for buf, idx in zip(bufs[0], indices):
      self.assertAllClose(buf.to_py(), x[idx], check_dtypes=False)
|
2020-05-06 10:19:28 -07:00
|
|
|
|
2020-08-14 18:22:04 +02:00
|
|
|
|
2019-01-28 11:13:34 -08:00
|
|
|
if __name__ == '__main__':
  # Run the tests through absl, using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|