# rocm_jax/tests/pmap_test.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import itertools as it
import gc
import os
from random import shuffle
from typing import Optional, cast
import unittest
from unittest import SkipTest
import warnings
import weakref
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax import tree_util
from jax import lax
from jax._src.lax import parallel
from jax._src import api as src_api
from jax import random
from jax.core import ShapedArray
from jax import (pmap, soft_pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import device_array
import jax._src.lib
from jax._src.lib import xla_bridge
from jax._src.util import prod, safe_map
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.config import config
config.parse_flags_with_absl()
prev_xla_flags = None
compatible_shapes = [[(3,)], [(3, 4), (3, 1), (1, 4)], [(2, 3, 4), (2, 1, 4)]]
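# Helpers for enumerating batching configurations in the tests below. For each
# input shape the candidate batch dims are None (unmapped) plus every insertion
# position 0..ndim; all_bdims yields every combination except the all-None one,
# since at least one argument has to be mapped.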
def all_bdims(*shapes, pmap):
bdims = (it.chain([cast(Optional[int], None)], range(len(shape) + 1))
for shape in shapes)
return (t for t in it.product(*bdims) if not all(e is None for e in t))
def out_bdims(shape, pmap):
return (d[0] for d in all_bdims(shape, pmap=pmap) if d[0] is not None)
def add_bdim(bdim_size, bdim, shape):
shape = list(shape)
if bdim is not None:
shape.insert(bdim, bdim_size)
return tuple(shape)
def slicer(x, bdim):
if bdim is None:
return lambda _: x
else:
return lambda i: lax.index_in_dim(x, i, bdim, keepdims=False)
def args_slicer(args, bdims):
slicers = safe_map(slicer, args, bdims)
return lambda i: [sl(i) for sl in slicers]
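# A slicer extracts the i-th batch element of one argument; unmapped arguments
# (bdim None) pass through whole. E.g. args_slicer((x, y), (0, None))(3)
# returns [x[3], y].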
# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xla_bridge.get_backend.cache_clear()
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xla_bridge.get_backend.cache_clear()
ignore_jit_of_pmap_warning = partial(
jtu.ignore_warning, message=".*jit-of-pmap.*")
ignore_slow_all_to_all_warning = partial(
jtu.ignore_warning, message="all_to_all.*expect significant slowdowns.*")
ignore_xmap_warning = partial(
jtu.ignore_warning, message=".*is an experimental.*")
class PythonPmapTest(jtu.JaxTestCase):
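# Every case in this class goes through the pure-Python pmap dispatch path
# (src_api._python_pmap); a subclass may override the property below to run
# the same cases against a different dispatch implementation.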
@property
def pmap(self):
return src_api._python_pmap
def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
buf = sda.device_buffers[-1]
view = jnp.array(buf, copy=False)
self.assertArraysEqual(sda[-1], view)
self.assertEqual(buf.device(), view.device())
self.assertEqual(buf.unsafe_buffer_pointer(), view.unsafe_buffer_pointer())
copy = jnp.array(buf, copy=True)
self.assertArraysEqual(sda[-1], copy)
self.assertEqual(buf.device(), copy.device())
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
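# A -1 entry in device_mesh_shape is a wildcard: the dimension is inferred
# from the available device count via np.reshape, and the test is skipped if
# the counts are incompatible.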
def _getMeshShape(self, device_mesh_shape):
device_count = jax.device_count()
if any(size == -1 for size in device_mesh_shape):
try:
return np.arange(device_count).reshape(device_mesh_shape).shape
except ValueError as err:
msg = "device mesh shape {} not compatible with device count {}"
raise SkipTest(msg.format(device_mesh_shape, device_count)) from err
else:
if device_count % prod(device_mesh_shape):
msg = "device mesh size {} does not divide available device count {}"
raise SkipTest(msg.format(prod(device_mesh_shape), device_count))
else:
return device_mesh_shape
def testBasic(self):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testLowerCompile(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = f(x)
lowered = f.lower(x)
compiled = lowered.compile()
ans = compiled(x)
self.assertAllClose(ans, expected)
# It's a pair of: (positional args, as a tuple of their structures, kwargs).
for obj in [lowered, compiled]:
self.assertFalse(obj._no_kwargs)
self.assertEqual(obj.in_tree, jax.tree_flatten(((0,), {}))[1])
self.assertEqual(obj.in_avals, ((jax.ShapedArray(x.shape, x.dtype),), {}))
def testLowerCompileInTreeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
f_exe = f.lower(x).compile()
self.assertRaisesRegex(
TypeError, "function compiled for .*, called with .*",
lambda: f_exe([x]))
def testLowerCompileTrivial(self):
f = self.pmap(lambda x: x, axis_name='i')
x = np.arange(jax.device_count(), dtype=np.float32)
expected = f(x)
f_exe = f.lower(x).compile()
ans = f_exe(x)
self.assertAllClose(ans, expected)
def testLowerCompileTrivialInTreeMismatch(self):
f = self.pmap(lambda x: x, axis_name='i')
x = np.arange(jax.device_count(), dtype=np.float32)
f_exe = f.lower(x).compile()
self.assertRaisesRegex(
TypeError, "function compiled for .*, called with .*",
lambda: f_exe([x]))
def testLowerCompileArgTypeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=int).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = f.lower(x_f32).compile()
self.assertRaisesRegex(
TypeError,
"Computation compiled for input types:\n.*float32.*\n"
"called with:\n.*int32.*",
lambda: f_exe(x_i32))
def testLowerCompileMultiArg(self):
f = self.pmap(lambda x, y: x - lax.pmean(y, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = f(x, y)
f_exe = f.lower(x, y).compile()
ans = f_exe(x, y)
self.assertAllClose(ans, expected)
def testLowerCompileTrivialMultiArg(self):
f = self.pmap(lambda x, y: (x, y), axis_name='i')
x = y = np.arange(jax.device_count(), dtype=np.float32)
expected = f(x, y)
f_exe = f.lower(x, y).compile()
ans = f_exe(x, y)
self.assertAllClose(ans, expected)
def testLowerCompilerIR(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
def testLowerCompileCompilerIR(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsNotNone(f.compiler_ir())
def testLowerCompileExecutable(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsNotNone(f.runtime_executable())
def testMean(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.broadcast_to(np.mean(x, 0), x.shape)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testGather(self):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherBool(self):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = (x % 2).astype(np.bool_)
expected = np.array([x] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherTiled(self):
f = self.pmap(lambda x: lax.all_gather(x, 'i', tiled=True), axis_name='i')
device_count = jax.device_count()
shape = (device_count, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * device_count).reshape(device_count, -1)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceScatter(self):
f = self.pmap(lambda x: lax.psum_scatter(x, 'i'), axis_name='i')
device_count = jax.device_count()
shape = (device_count, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
ans = f(x)
for i, actual in enumerate(ans):
self.assertAllClose(actual, expected[i])
def testReduceScatterTiled(self):
f = self.pmap(lambda x: lax.psum_scatter(x, 'i', tiled=True), axis_name='i')
device_count = jax.device_count()
shape = (device_count, 4 * device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
ans = f(x)
scatter_len = len(expected) // device_count
for i, actual in enumerate(ans):
self.assertAllClose(actual,
expected[i * scatter_len:(i + 1) * scatter_len])
def testReduceScatterReplicaGroupsTiled(self):
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("test requires an even number of devices")
axis_index_groups = [[i for i in range(jax.device_count()) if i % 2 == 0],
[i for i in range(jax.device_count()) if i % 2 != 0]]
f = lambda x: lax.psum_scatter(
x, 'i', axis_index_groups=axis_index_groups, tiled=True)
f = self.pmap(f, axis_name='i')
shape = (replicas, 4 * replicas)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
group_1_result = np.sum(x[0::2,:], axis=0)
group_2_result = np.sum(x[1::2,:], axis=0)
# the result is scattered over (replicas // 2) devices
scatter_len = len(group_1_result) * 2 // replicas
for i, actual in enumerate(ans):
expected = group_1_result if i % 2 == 0 else group_2_result
self.assertAllClose(
actual, expected[i // 2 * scatter_len:(i // 2 + 1) * scatter_len])
@ignore_slow_all_to_all_warning()
def testTrees(self):
ptranspose = lambda x, axis_name: lax.all_to_all(x, axis_name, 0, 0)
def protate(x, axis_name):
n = lax.psum(1, axis_name)
return lax.ppermute(x, axis_name, [(i, (i + 1) % n) for i in range(n)])
tree_f = lambda f: partial(tree_util.tree_map, f)
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
np_transpose = tree_f(np.transpose)
np_rotate = tree_f(lambda x: np.concatenate([x[-1:], x[:-1]]))
n = jax.device_count()
x = {'a': np.arange(1 * n * n, 2 * n * n).reshape([n, n]),
'b': np.arange(2 * n * n, 3 * n * n).reshape([n, n]),
'c': np.arange(4 * n * n, 5 * n * n).reshape([n, n])}
assert_allclose = partial(tree_util.tree_map,
partial(self.assertAllClose, check_dtypes=False))
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
assert_allclose(jax_f(lax.psum)(x), np_f(np.sum)(x))
assert_allclose(jax_f(lax.pmean)(x), np_f(np.mean)(x))
assert_allclose(jax_f(ptranspose)(x), np_transpose(x))
assert_allclose(jax_f(protate)(x), np_rotate(x))
def testCollectivesWithTreesOfDifferentDtypes(self):
n = len(jax.devices())
x = {'a': np.arange(1 * n * n, 2 * n * n, dtype=np.float32).reshape([n, n]),
'b': np.arange(2 * n * n, 3 * n * n, dtype=np.int32).reshape([n, n]),
'c': np.arange(4 * n * n, 5 * n * n, dtype=np.float32).reshape([n, n]),
'd': np.arange(6 * n * n, 7 * n * n, dtype=np.int32).reshape([n, n])}
tree_f = lambda f: partial(tree_util.tree_map, f)
jax_f = lambda p: self.pmap(lambda x: p(x, 'i'), 'i')
np_f = lambda p: tree_f(lambda x: np.broadcast_to(p(x, 0), x.shape))
assert_allclose = partial(tree_util.tree_map,
partial(self.assertAllClose, check_dtypes=False))
assert_allclose(jax_f(lax.pmax)(x), np_f(np.max)(x))
assert_allclose(jax_f(lax.pmin)(x), np_f(np.min)(x))
assert_allclose(jax_f(lax.psum)(x), np_f(np.sum)(x))
assert_allclose(jax_f(lax.pmean)(x), np_f(np.mean)(x))
def testComplexPsum(self):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4 * 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
"split_axis": split_axis, "concat_axis": concat_axis}
for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAll(self, split_axis, concat_axis):
pmap_in_axis = 0
shape = (jax.device_count(),) * 3
x = np.arange(np.prod(shape)).reshape(shape)
@partial(self.pmap, axis_name='i')
def f(x):
return lax.all_to_all(x, 'i', split_axis, concat_axis)
y = f(x)
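# Reference: all_to_all under pmap is equivalent to moving the split axis of
# the full array into the mapped position and the mapped axis into the concat
# position. Axes inside the mapped function are numbered without the pmap
# axis, hence the adjustment of split_axis below.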
if pmap_in_axis <= split_axis:
split_axis += 1
ref = jnp.moveaxis(x, (pmap_in_axis, split_axis),
(concat_axis + 1, 0))
self.assertAllClose(y, ref)
@parameterized.named_parameters(
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
"split_axis": split_axis, "concat_axis": concat_axis}
for split_axis, concat_axis in it.product(range(2), range(2)))
@ignore_slow_all_to_all_warning()
def testAllToAllSplitAxis(self, split_axis, concat_axis):
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
pmap_in_axis = 0
shape = (4, 4, 4)
x = np.arange(np.prod(shape)).reshape(shape)
@partial(self.pmap, axis_name='i')
@partial(self.pmap, axis_name='j')
def f(x):
return lax.all_to_all(x, ('i', 'j'), split_axis, concat_axis)
unroll_shape = (2, 2, *shape[1:])
x_unroll = x.reshape(unroll_shape)
y_unroll = f(x_unroll)
y = y_unroll.reshape(shape)
if pmap_in_axis <= split_axis:
split_axis += 1
ref = jnp.moveaxis(x, (pmap_in_axis, split_axis),
(concat_axis + 1, 0))
self.assertAllClose(y, ref)
def testNestedBasic(self):
f = lambda x: lax.psum(lax.psum(x, 'i'), 'j')
f = self.pmap(self.pmap(f, 'i'), 'j')
2019-03-19 16:54:55 -07:00
def sum_and_broadcast(x, axis):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
shape = (jax.device_count(), 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
self.assertAllClose(ans, expected, check_dtypes=False)
def testMismatchedAxisSizes(self):
n = jax.device_count()
f = self.pmap(lambda x, y: x + y)
self.assertRaisesRegex(
ValueError,
"pmap got inconsistent sizes for array axes to be mapped",
lambda: f(self.rng().randn(n), self.rng().randn(n - 1)))
@parameterized.named_parameters(
{"testcase_name": "_mesh={}".format(device_mesh_shape).replace(" ", ""),
"device_mesh_shape": device_mesh_shape}
for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testNestedShardingAndStacking(self, device_mesh_shape):
mesh_shape = self._getMeshShape(device_mesh_shape)
f = lambda x: x
f = self.pmap(self.pmap(f, 'i'), 'j')
shape = mesh_shape + (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = x
self.assertEqual(ans.shape, expected.shape)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPartiallyMapped(self):
f = self.pmap(lambda x, y: x, in_axes=(None, 0))
g = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))
mesh_shape = (jax.device_count(),)
shape = mesh_shape + (4,)
x = np.array(3., dtype=np.float32)
y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
f_expected = np.broadcast_to(x, mesh_shape)
f_ans = f(x, y)
self.assertAllClose(f_ans, f_expected)
self.assertIsInstance(f_ans, pxla.ShardedDeviceArray)
# the output is actually replicated (has the same values in each device buffer)
# but out_axes is implicitly 0, so we shouldn't have replication in the
# sharding spec.
self.assertEmpty([a for a in f_ans.sharding_spec.mesh_mapping
if isinstance(a, pxla.Replicated)])
g_expected = np.broadcast_to(x - np.sum(y, 0, keepdims=True), shape)
g_ans = g(x, y)
self.assertAllClose(g_ans, g_expected)
self.assertIsInstance(g_ans, pxla.ShardedDeviceArray)
self.assertEmpty([a for a in g_ans.sharding_spec.mesh_mapping
if isinstance(a, pxla.Replicated)])
def testReplicate(self):
base = np.array([3.,4.], dtype=np.float32)
num_devices = jax.device_count()
replicated = pxla.replicate(base, num_devices, num_devices, in_axis=None)
self.assertAllClose(base, replicated)
self.assertEmpty([a for a in replicated.sharding_spec.mesh_mapping
if not isinstance(a, pxla.Replicated)])
@parameterized.named_parameters(
{"testcase_name": "_mesh={}".format(device_mesh_shape).replace(" ", ""),
"device_mesh_shape": device_mesh_shape}
for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testPartiallyMappedNested(self, device_mesh_shape):
mesh_shape = self._getMeshShape(device_mesh_shape)
f = self.pmap(lambda x, y: x - lax.psum(y, 'i'), axis_name='i', in_axes=(None, 0))
f = self.pmap(f, axis_name='j', in_axes=(None, 0))
x = 3.
y = np.arange(prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
expected = np.broadcast_to(x - np.sum(y, 1, keepdims=True), mesh_shape)
ans = f(x, y)
self.assertAllClose(ans, expected, check_dtypes=False)
def testJvpAndPartialEval(self):
@partial(self.pmap, axis_name='i')
def f(x):
return jnp.sin(x)
def splitjvp(x):
_, jvp = linearize(f, x)
return jvp(jnp.ones_like(x))
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = np.cos(x)
ans = splitjvp(x)
self.assertAllClose(ans, expected, check_dtypes=False)
make_jaxpr(splitjvp)(x) # doesn't crash
def testGradBasic(self):
@partial(self.pmap, axis_name='i')
def f(x):
return jnp.sin(x)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
expected = grad(lambda x: jnp.sum(f(x)))(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testGradOfPsum(self):
@partial(self.pmap, axis_name='i')
def f(x):
return lax.psum(x, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
def testGradOfJvp(self):
@partial(self.pmap, axis_name='i')
def f(x):
return jnp.sin(x)
def splitjvp(x):
_, jvp = linearize(f, x)
return jvp(jnp.ones_like(x))
fun = lambda x: jnp.sum(jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1])
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
expected = grad(fun)(x)
self.assertAllClose(ans, expected)
def testTwoArgsGrad(self):
def f(x, y):
return lax.psum(5. * jnp.cos(x) * jnp.sin(y), 'i')
f = self.pmap(f, 'i')
def g(x, y):
tot = jnp.sum(5. * jnp.cos(x) * jnp.sin(y))
return tot * jnp.ones_like(x) # broadcast to map like pjit does
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = 4 + x
ans = grad(lambda x, y: jnp.sum(f(x, y)))(x, y)
expected = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": "_mesh={}".format(device_mesh_shape).replace(" ", ""),
"device_mesh_shape": device_mesh_shape}
for device_mesh_shape in [(1, 1), (2, -1), (-1, 2)])
def testNestedWithClosure(self, device_mesh_shape):
mesh_shape = self._getMeshShape(device_mesh_shape)
@partial(self.pmap, axis_name='i')
def test_fun(x):
y = jnp.sum(jnp.sin(x))
@partial(self.pmap, axis_name='j')
def g(z):
return 3. * jnp.exp(jnp.sin(x).sum() * jnp.cos(y) * jnp.tan(z))
return grad(lambda w: jnp.sum(g(w)))(x)
@vmap
def baseline_fun(x):
y = jnp.sum(jnp.sin(x))
@vmap
def g(z):
return 3. * jnp.exp(jnp.sin(x).sum() * jnp.cos(y) * jnp.tan(z))
return grad(lambda w: jnp.sum(g(w)))(x)
shape = mesh_shape + (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
self.assertAllClose(ans, expected, atol=1e-3)
def testShardedDeviceArrays(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
# test that we can pass in and out ShardedDeviceArrays
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertIsInstance(y, device_array.DeviceArray)
self.assertNotIsInstance(y, np.ndarray)
self.assertAllClose(y, 2 * x, check_dtypes=False)
z = f(y)
self.assertIsInstance(z, pxla.ShardedDeviceArray)
self.assertIsInstance(z, device_array.DeviceArray)
self.assertNotIsInstance(z, np.ndarray)
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)
# test that we can pass in a regular DeviceArray
y = f(device_put(x))
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertAllClose(y, 2 * x, check_dtypes=False)
# test that we can pass a ShardedDeviceArray to a regular jit computation
z = y + y
self.assertAllClose(z, 2 * 2 * x, check_dtypes=False)
# test that we can handle device movement on dispatch
y = pxla.make_sharded_device_array(y.aval, y.sharding_spec,
y.device_buffers[::-1])
z = f(y)
self.assertAllClose(z, 2 * 2 * x[::-1], check_dtypes=False)
# test that the repr doesn't crash
repr(z)
# test that we can lexically capture a sda as a constant.
g = jit(lambda z: z + y)
self.assertAllClose(g(7), y + 7)
# Tests edge cases in lax._reshape_sharded_device_array
@parameterized.named_parameters(
{"testcase_name": "_in={}_out={}".format(in_shape, out_shape)
.replace(" ", ""),
"in_shape": in_shape, "out_shape": out_shape}
for in_shape, out_shape in [
[(1,1), (1,)], [(1,), (1,1)], [(1,), ()], [(4,7), (2,2,7)]
])
def testShardedDeviceArrayReshape(self, in_shape, out_shape):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")
x = np.arange(prod(in_shape)).reshape(in_shape)
sharded_x = self.pmap(lambda x: x)(x)
self.assertAllClose(sharded_x.reshape(out_shape), x.reshape(out_shape),
check_dtypes=False)
def testPsumMultiple(self):
f = lambda x: lax.psum(x, ('i', 'j'))
f = self.pmap(self.pmap(f, 'i'), 'j')
def sum_and_broadcast(x, axis):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
device_count = jax.device_count()
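# Use a (device_count // 2, 2) device mesh when there are at least four
# devices and the count is even; otherwise fall back to a trivial inner axis.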
num_pairs, ragged = divmod(device_count, 2)
if num_pairs > 1 and not ragged:
shape = (num_pairs, 2, 4)
else:
shape = (device_count, 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumConstantReplicaGroups(self):
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("test requires an even number of devices")
axis_index_groups = np.arange(replicas).reshape(
2, replicas // 2).tolist()
f = lambda x: x - lax.psum(2., 'i', axis_index_groups=axis_index_groups)
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected_psum = 2. * replicas // 2
expected = x - expected_psum
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("tpu")
def testPsumUnevenReplicaGroups(self):
replicas = jax.device_count()
if replicas <= 2:
raise SkipTest("Test expected devices greater than 2.")
axis_index_groups = [[0,1], np.arange(2,replicas)]
f = lambda x: x - lax.psum(x, 'i', axis_index_groups=axis_index_groups)
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
def sum_helper(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(len(a), x.shape[1]))
expected_psum_1 = sum_helper(x[0:2])
expected_psum_2 = sum_helper(x[2:])
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 0)
expected = x - expected_psum
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPsumReplicaGroups(self):
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("test requires an even number of devices")
axis_index_groups = np.arange(replicas).reshape(
2, replicas // 2).tolist()
f = lambda x: x - lax.psum(x, 'i', axis_index_groups=axis_index_groups)
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
def sum_helper(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(replicas // 2, x.shape[1]))
expected_psum_1 = sum_helper(x[:replicas // 2])
expected_psum_2 = sum_helper(x[replicas // 2:])
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 0)
expected = x - expected_psum
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherReplicaGroups(self):
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
axis_index_groups = np.arange(replicas, dtype=np.int32)
axis_index_groups = axis_index_groups.reshape((replicas // 2, 2)).T
axis_index_groups = axis_index_groups.tolist()
f = lambda x: lax.all_gather(x, 'i', axis_index_groups=axis_index_groups)
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
group_1_result = x[0::2]
group_2_result = x[1::2]
expected = np.empty((replicas, replicas // 2, x.shape[1]))
expected[0::2] = group_1_result
expected[1::2] = group_2_result
self.assertAllClose(ans, expected, check_dtypes=False)
def testGatherReplicaGroupsInterleaved(self):
replicas = jax.device_count()
if replicas % 2 != 0:
raise SkipTest("Test expected an even number of devices greater than 1.")
indexes = np.arange(replicas)
indexes = np.concatenate([indexes[::2], indexes[1::2]])
axis_index_groups = indexes.reshape(2, replicas // 2).tolist()
f = lambda x: lax.all_gather(x, 'i', axis_index_groups=axis_index_groups)
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = np.zeros((replicas, replicas // 2, x.shape[1]))
expected[::2] = x[::2]
expected[1::2] = x[1::2]
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_slow_all_to_all_warning()
def testGradOfGather(self):
@partial(self.pmap, axis_name='i')
def f(x):
return lax.all_gather(x, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
def testNestedPmapReplicaGroups(self):
replicas = jax.device_count()
if replicas % 4 != 0:
raise SkipTest("test requires a device count divisible by 4")
axis_index_groups = np.arange(replicas // 2).reshape(
2, replicas // 4).tolist()
f = lambda x: x - lax.psum(x, 'i', axis_index_groups=axis_index_groups)
f1 = self.pmap(self.pmap(f, 'i'), 'j')
f2 = self.pmap(lambda x: self.pmap(f, 'i')(x) + 1., 'j') # "imperfectly nested" case
f3 = self.pmap(self.pmap(f, 'j'), 'i')
shape = (2, replicas // 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
def sum_helper_f1(a):
return np.broadcast_to(a.sum(1, keepdims=True),
(shape[0], shape[1] // 2, shape[2]))
expected_psum_1 = sum_helper_f1(x[:, :replicas // 4])
expected_psum_2 = sum_helper_f1(x[:, replicas // 4:])
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 1)
expected = x - expected_psum
ans = f1(x)
self.assertAllClose(ans, expected)
expected = x - expected_psum + 1.
ans = f2(x)
self.assertAllClose(ans, expected)
shape = (replicas // 2, 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
def sum_helper_f3(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(shape[0] // 2, shape[1], shape[2]))
expected_psum_1 = sum_helper_f3(x[:replicas // 4])
expected_psum_2 = sum_helper_f3(x[replicas // 4:])
expected_psum = np.concatenate([expected_psum_1, expected_psum_2], 0)
expected = x - expected_psum
ans = f3(x)
self.assertAllClose(ans, expected)
def testAxisGroups(self):
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
groups = xla.axis_groups(axis_env, 'i')
self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7)))
groups = xla.axis_groups(axis_env, 'j')
self.assertEqual(groups, ((0, 1), (2, 3), (4, 5), (6, 7)))
groups = xla.axis_groups(axis_env, ('i', 'j'))
self.assertEqual(groups, ((0, 1, 2, 3, 4, 5, 6, 7,),))
groups = xla.axis_groups(axis_env, ('j', 'i'))
self.assertEqual(len(groups), 1)
self.assertEqual((tuple(sorted(groups[0])),),
((0, 1, 2, 3, 4, 5, 6, 7,),)) # order doesn't matter
def testCollectivePermute(self):
device_count = jax.device_count()
rotation = [(i, (i + 1) % device_count) for i in range(device_count)]
f = lambda x: lax.ppermute(x, perm=rotation, axis_name='i')
f = self.pmap(f, 'i')
x = jnp.arange(4 * device_count).reshape((device_count, 4))
ans = f(x)
expected = np.roll(x, shift=1, axis=0)
self.assertAllClose(ans, expected, check_dtypes=False)
@jtu.skip_on_devices("cpu")
def testCollectivePermuteGrad(self):
device_count = jax.device_count()
shift_right = [(i, (i + 1)) for i in range(device_count - 1)]
f = lambda x: lax.ppermute(x, perm=shift_right, axis_name='i')
y = np.pi + np.arange(device_count, dtype=np.float32)
g = lambda x: jnp.sum(y * self.pmap(f, 'i')(x))
x = np.arange(device_count, dtype=np.float32)
ans = grad(g)(x)
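# The transpose of ppermute applies the inverse permutation: this non-cyclic
# right shift pulls each cotangent from the right neighbor, leaving a zero in
# the last slot.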
expected = np.concatenate([np.pi + np.arange(1, device_count), [0]])
self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectivePermuteCyclicGrad(self):
device_count = jax.device_count()
shift_right = [(i, (i + 1) % device_count) for i in range(device_count)]
f = lambda x: lax.ppermute(x, perm=shift_right, axis_name='i')
y = np.pi + np.arange(device_count, dtype=np.float32)
g = lambda x: jnp.sum(y * self.pmap(f, 'i')(x))
x = np.arange(device_count, dtype=np.float32)
ans = grad(g)(x)
expected = np.roll(np.pi + np.arange(device_count), -1)
self.assertAllClose(ans, expected, check_dtypes=False)
jtu.check_grads(g, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2)
def testCollectivePermuteCyclicWithPShuffle(self):
device_count = jax.device_count()
values = np.arange(device_count)
shift_right = [(i - 1) % device_count for i in range(device_count)]
f = lambda x: lax.pshuffle(x, perm=shift_right, axis_name='i')
expected = np.roll(values, 1)
ans = np.asarray(self.pmap(f, "i")(values))
self.assertAllClose(ans, expected, check_dtypes=False)
def testPShuffleWithBadPerm(self):
device_count = jax.device_count()
bad_perm = list(range(device_count))
bad_perm[0] = 1
f = lambda x: lax.pshuffle(x, perm=bad_perm, axis_name='i')
g = lambda: self.pmap(f, "i")(np.arange(device_count))
self.assertRaisesRegex(
ValueError,
"`perm` does not represent a permutation: \\[1.*\\]", g)
def testPpermuteWithZipObject(self):
# https://github.com/google/jax/issues/1703
num_devices = jax.device_count()
perm = [num_devices - 1] + list(range(num_devices - 1))
f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i")
result = f(jnp.arange(num_devices, dtype=jnp.float32))
expected = jnp.asarray(perm, dtype=jnp.float32)
self.assertAllClose(result, expected)
def testRule30(self):
# This is a test of collective_permute implementing a simple halo exchange
# to run a rule 30 simulation: https://en.wikipedia.org/wiki/Rule_30
# Halo exchange should be useful in spatially-sharded convolutions and in
# other simulations.
device_count = jax.device_count()
def send_right(x, axis_name):
left_perm = [(i, (i + 1) % device_count) for i in range(device_count)]
return lax.ppermute(x, perm=left_perm, axis_name=axis_name)
def send_left(x, axis_name):
left_perm = [((i + 1) % device_count, i) for i in range(device_count)]
return lax.ppermute(x, perm=left_perm, axis_name=axis_name)
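# Rule 30 update: new_cell = left XOR (center OR right).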
def update_board(board):
left = board[:-2]
right = board[2:]
center = board[1:-1]
return lax.bitwise_xor(left, lax.bitwise_or(center, right))
@partial(self.pmap, axis_name='i')
def step(board_slice):
left, right = board_slice[:1], board_slice[-1:]
right, left = send_left(left, 'i'), send_right(right, 'i')
enlarged_board_slice = jnp.concatenate([left, board_slice, right])
return update_board(enlarged_board_slice)
board = np.zeros(40, dtype=bool)
board[board.shape[0] // 2] = True
reshaped_board = board.reshape((device_count, -1))
boards = []
def print_board(board):
boards.append(''.join('*' if x else ' ' for x in board.ravel()))
print_board(reshaped_board)
for _ in range(20):
reshaped_board = step(reshaped_board)
print_board(reshaped_board)
ans = '\n'.join(boards)
expected = '\n'.join((
' * ',
' *** ',
' ** * ',
' ** **** ',
' ** * * ',
' ** **** *** ',
' ** * * * ',
' ** **** ****** ',
' ** * *** * ',
' ** **** ** * *** ',
' ** * * **** ** * ',
' ** **** ** * * **** ',
' ** * *** ** ** * * ',
' ** **** ** *** *** ** *** ',
' ** * * *** * *** * * ',
' ** **** ** * * ***** ******* ',
' ** * *** **** * *** * ',
' ** **** ** *** ** ** * *** ',
' ** * * *** * ** *** **** ** * ',
' ** **** ** * ****** * * *** ****',
' * * *** **** **** *** ** * ',
))
print(ans)
self.assertEqual(ans, expected)
def testReduceMax(self):
f = self.pmap(lambda x: x - lax.pmax(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.max(x, 0)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testReduceMin(self):
f = self.pmap(lambda x: x - lax.pmin(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.min(x, 0)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testDeviceCountError(self):
device_count = jax.device_count()
f = self.pmap(lambda x: x)
x = jnp.arange(device_count + 1)
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
f = self.pmap(lambda x: x)
x = np.ones((device_count + 1, 10))
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
f = self.pmap(lambda x: self.pmap(lambda x: x)(x))
x = np.ones((device_count, 2, 10))
self.assertRaisesRegex(ValueError, ".*requires.*replicas", lambda: f(x))
def testPmapConstant(self):
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): fix this
expected = np.repeat(3, device_count)
self.assertAllClose(ans, expected, check_dtypes=False)
f = self.pmap(lambda x: (x, 3))
x = np.arange(device_count)
with jtu.assert_num_jit_and_pmap_compilations(1):
_, ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPmapConstantDevices(self):
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
devices = jax.devices()[:-1]
shuffle(devices)
f = self.pmap(lambda x: 3, devices=devices)
x = jnp.arange(len(devices))
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
expected = np.repeat(3, len(devices))
self.assertAllClose(ans, expected, check_dtypes=False)
# Test that 'ans' was properly replicated across devices.
self.assertEqual([b.device() for b in ans.device_buffers], devices)
def testPmapConstantError(self):
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count + 1)
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
r"but only \d+ XLA devices are available .*"),
lambda: f(x))
# TODO(mattjj): test error message with explicit devices
# f = pmap(lambda x: 3, devices=[jax.devices()[0]])
# x = jnp.arange(2)
# self.assertRaisesRegex(
# ValueError, r"Cannot replicate across \d+ replicas because only \d+ "
# r"local devices are available.", lambda: f(x))
def testNestedPmapConstant(self):
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
expected = 3 * np.ones(shape[:2])
self.assertAllClose(ans, expected, check_dtypes=False)
# Test that 'ans' was properly replicated across devices.
expected_sharded = self.pmap(self.pmap(lambda x: x))(expected)
self.assertEqual([b.device() for b in ans.device_buffers],
[b.device() for b in expected_sharded.device_buffers])
f = self.pmap(self.pmap(lambda x: (x, 3)))
x_sharded, ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
self.assertEqual([b.device() for b in ans.device_buffers],
[b.device() for b in x_sharded.device_buffers])
@unittest.skip("Nested pmaps with devices not yet implemented")
def testNestedPmapConstantDevices(self):
if jax.device_count() < 6:
raise SkipTest("this test requires >= 6 devices")
devices = jax.devices()[:-2]
shuffle(devices)
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
shape = (2, len(devices) // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
expected = 3 * np.ones(shape[:2])
self.assertAllClose(ans, expected, check_dtypes=False)
# Test that 'ans' was properly replicated across devices.
expected_sharded = self.pmap(self.pmap(lambda x: x), devices=devices)(expected)
self.assertEqual([b.device() for b in ans.device_buffers],
[b.device() for b in expected_sharded.device_buffers])
def testNestedPmapConstantError(self):
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
r"but only \d+ XLA devices are available .*"),
lambda: f(x))
# TODO(mattjj): check error message with explicit devices
# if jax.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1])
# shape = (2, jax.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
# (r"compiling computation that requires \d+ replicas, "
# r"but only \d+ XLA devices are available"),
# lambda: f(x))
def testCollectiveConstant(self):
device_count = jax.device_count()
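# psum of the constant 1 over an axis simply counts the devices in that axis.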
f = self.pmap(lambda x: lax.psum(1, 'i'), 'i')
x = jnp.arange(device_count)
ans = f(x)
expected = np.repeat(device_count, device_count)
self.assertAllClose(ans, expected, check_dtypes=False)
def testCollectiveConstantNested(self):
device_count = jax.device_count()
@partial(self.pmap, axis_name='i')
def f(x):
@partial(self.pmap, axis_name='j')
def g(y):
a = lax.psum(1, 'i')
b = lax.psum(1, 'j')
c = lax.psum(1, ('i', 'j'))
return a, b, c
return g(x)
shape = (device_count, 1, 4)
x = jnp.arange(prod(shape)).reshape(shape)
a, b, c = f(x)
self.assertEqual(a.shape, shape[:-1])
self.assertEqual(b.shape, shape[:-1])
self.assertEqual(c.shape, shape[:-1])
    self.assertEqual(a.ravel()[0], device_count)      # psum over 'i' (size device_count)
    self.assertEqual(b.ravel()[0], 1)                 # psum over 'j' (size 1)
    self.assertEqual(c.ravel()[0], device_count * 1)  # psum over ('i', 'j')
def testAxisIndex(self):
device_count = jax.device_count()
f = self.pmap(lambda x: x + lax.axis_index('i'), 'i')
x = jnp.ones(device_count)
ans = f(x)
expected = 1 + np.arange(device_count)
self.assertAllClose(ans, expected, check_dtypes=False)
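  # A hedged follow-on sketch (test name is ours): the per-device axis_index
  # values enumerate 0..device_count-1, so their psum is the triangular number
  # device_count * (device_count - 1) / 2 on every device.
  def testAxisIndexPsumSketch(self):
    device_count = jax.device_count()
    f = self.pmap(lambda x: x + lax.psum(lax.axis_index('i'), 'i'), 'i')
    ans = f(jnp.zeros(device_count))
    expected = np.full(device_count, device_count * (device_count - 1) / 2)
    self.assertAllClose(ans, expected, check_dtypes=False)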
def testAxisIndexNestedPmap(self):
device_count = jax.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
f = lambda axis: self.pmap(self.pmap(lambda x: x + lax.axis_index(axis), 'j'), 'i')
x = jnp.ones((2, 2))
expected_j = np.broadcast_to(1 + np.arange(2), (2, 2))
self.assertAllClose(f('j')(x), expected_j, check_dtypes=False)
self.assertAllClose(f('i')(x), expected_j.T, check_dtypes=False)
def testAxisIndexNd(self):
device_count = jax.device_count()
if device_count < 4:
raise SkipTest("test requires at least four devices")
f = lambda axes: self.pmap(self.pmap(lambda x: x + lax.axis_index(axes), 'j'), 'i')
x = jnp.ones((2, 2))
expected = 1 + np.arange(4).reshape((2, 2))
self.assertAllClose(f(('i', 'j'))(x), expected, check_dtypes=False)
self.assertAllClose(f(('j', 'i'))(x), expected.T, check_dtypes=False)
def testAxisIndexInInitialStyle(self):
@partial(self.pmap, axis_name='i')
def f(x):
def body(carry, i):
return carry + i + lax.axis_index('i'), None
return lax.scan(body, 0, x)[0]
device_count = jax.device_count()
shape = (device_count, 10)
self.assertAllClose(f(jnp.ones(shape, dtype=int)),
(np.arange(device_count) + 1) * 10)
def testVmapOfPmap(self):
device_count = jax.device_count()
f0 = lambda x: x
f1 = self.pmap(f0, axis_name='i')
ax = self.rng().randn(2, device_count, 50, 60)
bx = vmap(f1)(ax)
self.assertAllClose(ax, bx, check_dtypes=False)
def testVmapOfPmap2(self):
N_DEVICES = jax.device_count()
keys = random.split(random.PRNGKey(1), 13) # [13, 2]
@self.pmap
def g(key):
_ = random.normal(key, ())
return 0.
@vmap
def s(keys):
keys = tree_util.tree_map(
lambda x: jnp.broadcast_to(x, (N_DEVICES,) + x.shape),
keys)
return g(keys)
ans = s(keys) # doesn't crash
self.assertEqual(ans.shape, (13, N_DEVICES))
def testVmapOfPmap3(self):
# https://github.com/google/jax/issues/3399
device_count = jax.device_count()
if device_count < 2:
raise SkipTest("test requires at least two devices")
def map_version(qs, pts):
return jax.lax.map(lambda x: func(x, pts), qs)
def vmap_version(qs, pts):
return jax.vmap(func, in_axes=(0, None))(qs, pts)
def func(q, pts):
q_from_pmap = self.pmap(lambda x, y: y, in_axes=(0, None))(pts, q)
return q, q_from_pmap
pts = jnp.ones(device_count)
qs = jnp.asarray(((0,0), (3,3), (2,2)))
with ignore_jit_of_pmap_warning():
_, expected = map_version(qs, pts)
_, ans = vmap_version(qs, pts)
self.assertAllClose(ans, expected, check_dtypes=False)
def testVmapOfPmapNonLeadingAxis(self):
device_count = jax.device_count()
f0 = lambda x: x
f1 = self.pmap(f0, axis_name='i')
ax = self.rng().randn(device_count, 2, 50, 60)
bx = vmap(f1, in_axes=2, out_axes=2)(ax)
self.assertAllClose(ax, bx, check_dtypes=False)
def testVmapOfPmapTuple(self):
device_count = jax.device_count()
f0 = lambda *x: x
f1 = self.pmap(f0, axis_name='i')
ax = self.rng().randn(device_count, 2, 50, 60)
ay = self.rng().randn(device_count, 30, 2)
az1 = self.rng().randn(device_count, 20)
az2 = self.rng().randn(2, device_count, 20)
bx, by, bz = vmap(f1, in_axes=(1, 2, (None, 0)), out_axes=(1, 2, 0))(ax, ay, (az1, az2))
self.assertAllClose(ax, bx, check_dtypes=False)
self.assertAllClose(ay, by, check_dtypes=False)
bz1, bz2 = bz
expected_bz1 = np.broadcast_to(az1, (2,) + az1.shape)
self.assertAllClose(expected_bz1, bz1, check_dtypes=False)
    self.assertAllClose(az2, bz2, check_dtypes=False)
@ignore_slow_all_to_all_warning()
def testPswapaxes(self):
device_count = jax.device_count()
shape = (device_count, 3, device_count, 5)
x = np.arange(prod(shape)).reshape(shape)
ans = self.pmap(lambda x: lax.pswapaxes(x, 'i', 1), axis_name='i')(x)
expected = np.swapaxes(x, 0, 2)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_slow_all_to_all_warning()
def testGradOfPswapaxes(self):
device_count = jax.device_count()
shape = (device_count, 1, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)
@partial(self.pmap, axis_name='i')
def f(x, w):
g = lambda x: jnp.sum(lax.pswapaxes(x, 'i', 1) * w)
return grad(g)(x)
ans = f(x, w)
expected = np.tile(w, reps=device_count).reshape(shape)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_slow_all_to_all_warning()
def testAllToAllReplicaGroups(self):
# If num_devices = 4, these would be the inputs/outputs:
# input = [[0, 1], [2, 3], [4, 5], [6, 7]]
# axis_index_groups = [[0, 2], [1, 3]]
# output = [[0, 4], [2, 6], [1, 5], [3, 7]]
#
    # This is essentially like splitting the rows of the input into two
    # groups and swapping the two inner axes (axis=1 and axis=2), which is
    # exactly what this test case checks.
device_count = jax.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2)
x = np.arange(prod(shape)).reshape(shape)
axis_index_groups = np.arange(device_count, dtype=np.int32)
axis_index_groups = axis_index_groups.reshape((device_count // 2, 2)).T
axis_index_groups = axis_index_groups.tolist()
@partial(self.pmap, axis_name='i')
def fn(x):
return lax.all_to_all(x, 'i', 0, 0, axis_index_groups=axis_index_groups)
expected = np.swapaxes(
x.reshape((device_count // 2, 2, device_count // 2)),
0, 2).reshape(shape)
self.assertAllClose(fn(x), expected, check_dtypes=False)
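  # A hedged pure-NumPy illustration (test name is ours) of the
  # num_devices == 4 example in the comment above; it needs no accelerator
  # and mirrors the reference computation used by the test.
  def testAllToAllReplicaGroupsDocExample(self):
    x = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
    expected = np.array([[0, 4], [2, 6], [1, 5], [3, 7]])
    ans = np.swapaxes(x.reshape((2, 2, 2)), 0, 2).reshape((4, 2))
    self.assertAllClose(ans, expected, check_dtypes=False)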
@ignore_slow_all_to_all_warning()
def testGradOfAllToAllReplicaGroups(self):
device_count = jax.device_count()
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2, 1)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)
axis_index_groups = np.arange(device_count, dtype=np.int32)
axis_index_groups = axis_index_groups.reshape((2, device_count // 2))
axis_index_groups = axis_index_groups.tolist()
@partial(self.pmap, axis_name='i')
def fn(x, w):
g = lambda x: jnp.sum(lax.all_to_all(x, 'i', 0, 1, axis_index_groups=axis_index_groups) * w)
return grad(g)(x)
expected = np.ones_like(x) * w[:, np.newaxis, np.newaxis]
expected = np.swapaxes(
expected.reshape((2, device_count // 2, device_count // 2)),
1, 2).reshape(shape)
self.assertAllClose(fn(x, w), expected, check_dtypes=False)
def testReshardInput(self):
if jax.device_count() < 6:
raise SkipTest("testReshardInput requires 6 devices")
# Manually construct a ShardedDeviceArray with the wrong sharding for the
# subsequent pmap
    shard_shape = (3, 2)
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
bufs = pxla.device_put(shard, jax.devices()[:4], replicate=True)
aval = ShapedArray((6,4), shard.dtype)
sharding_spec = pxla.ShardingSpec(
sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=map(pxla.ShardedAxis, (0, 1)))
arr = pxla.make_sharded_device_array(aval, sharding_spec, bufs)
r = self.pmap(lambda x: x + 1)(arr)
self.assertAllClose(r, arr + 1)
self.assertEqual(len(r.device_buffers), 6)
@ignore_xmap_warning()
def testSoftPmapBatchMatmul(self):
n = 4 * jax.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
ans = soft_pmap(jnp.dot, 'i')(xs, ys)
expected = np.einsum('nij,njk->nik', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
def testSoftPmapBatchMatmulJit(self):
n = 4 * jax.device_count()
xs = np.arange(n * 2 * 3).reshape(n, 2, 3)
ys = np.arange(n * 3 * 4).reshape(n, 3, 4)
ans = soft_pmap(jit(jnp.dot), 'i')(xs, ys)
expected = np.einsum('nij,njk->nik', xs, ys)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
def testSoftPmapPsumConstant(self):
n = 4 * jax.device_count()
def f(_):
return lax.psum(1, 'i')
ans = soft_pmap(f, 'i')(jnp.ones(n))
expected = n * np.ones(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
def testSoftPmapPsum(self):
n = 4 * jax.device_count()
def f(x):
return x / lax.psum(x, 'i')
ans = soft_pmap(f, 'i')(jnp.ones(n))
expected = np.ones(n) / n
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
def testSoftPmapAxisIndex(self):
n = 4 * jax.device_count()
def f(x):
return x * lax.axis_index('i')
ans = soft_pmap(f, 'i')(2 * jnp.ones(n))
expected = 2 * np.arange(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
def testSoftPmapOfJit(self):
n = 4 * jax.device_count()
def f(x):
return 3 * x
ans = soft_pmap(jit(f), 'i')(np.arange(n))
expected = 3 * np.arange(n)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
@unittest.skip("not implemented") # TODO(mattjj): re-implement
def testSoftPmapNested(self):
n = 4 * jax.device_count()
@partial(soft_pmap, axis_name='i')
@partial(soft_pmap, axis_name='j')
def f(x):
i_size = lax.psum(1, 'i')
return x + lax.axis_index('i') + i_size * lax.axis_index('j')
ans = f(jnp.zeros((n, n)))
expected = np.arange(n ** 2).reshape(n, n).T
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
@unittest.skip("not implemented") # TODO(mattjj): re-implement
def testGradOfSoftPmap(self):
n = 4 * jax.device_count()
@partial(soft_pmap, axis_name='i')
def f(x):
return x * lax.axis_index('i')
ans = grad(lambda x: jnp.sum(f(x)))(jnp.zeros((n, n)))
expected = np.repeat(np.arange(n)[:, None], n, axis=1)
self.assertAllClose(ans, expected, check_dtypes=False)
@ignore_xmap_warning()
def testSoftPmapDevicePersistence(self):
device_count = jax.device_count()
shape = (2 * 2 * device_count, 2, 3)
# check that we can maintain device persistence across calls
x = np.arange(prod(shape)).reshape(shape)
x = soft_pmap(lambda x: x)(x)
self.assertIsInstance(x, pxla.ShardedDeviceArray)
x._npy_value = np.float32(np.nan) # can't be coerced to ndarray for xfer
x = soft_pmap(lambda x: x)(x) # doesn't crash
self.assertIsInstance(x, pxla.ShardedDeviceArray)
@unittest.skip("the underlying code here is broken") # TODO(mattjj)
def testSoftPmapAllToAll(self):
n = 4 * jax.device_count()
def f(x):
return lax.all_to_all(x, 'i', 0, 0)
ans = soft_pmap(f, 'i')(jnp.arange(n ** 2).reshape(n, n))
expected = np.arange(n ** 2).reshape(n, n).T
self.assertAllClose(ans, expected, check_dtypes=False)
def testShardedDeviceArrayBlockUntilReady(self):
x = np.arange(jax.device_count())
x = self.pmap(lambda x: x)(x)
x.block_until_ready() # doesn't crash
@ignore_jit_of_pmap_warning()
def testJitPmapComposition(self):
f = lambda x: x - lax.psum(x, 'i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = jit(self.pmap(f, 'i'))(x)
self.assertAllClose(ans, expected, check_dtypes=False)
ans = self.pmap(jit(f), 'i')(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testCompositionWithJitTwice(self):
@jit
def f(x):
y = 2 * x
@jit
def g(z):
return self.pmap(lambda x: x[jnp.newaxis] * y)(z)
return g(x)
f(np.arange(1.).reshape((1, 1))) # doesn't crash
@ignore_jit_of_pmap_warning()
def testIssue1065(self):
# from https://github.com/google/jax/issues/1065
device_count = jax.device_count()
def multi_step_pmap(state, count):
@partial(self.pmap, axis_name='x')
@jit
def exchange_and_multi_step(state):
return state
@jit
def time_evolution(state):
return lax.fori_loop(0, count, lambda i, s: exchange_and_multi_step(s), state)
return time_evolution(state)
multi_step_pmap(jnp.zeros((device_count,)), count=1)
def testShardedDeviceArrayGetItem(self):
f = lambda x: 2 * x
f = self.pmap(f, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = f(x)
self.assertIsInstance(y, jnp.ndarray)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
z = y[0] # doesn't crash
self.assertAllClose(z, 2 * x[0], check_dtypes=False)
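  # A hedged follow-on sketch (test name is ours): slicing along the leading
  # axis of the ShardedDeviceArray should likewise avoid a crash and match
  # the NumPy result.
  def testShardedDeviceArrayGetItemSliceSketch(self):
    f = self.pmap(lambda x: 2 * x, axis_name='i')
    shape = (jax.device_count(), 4)
    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    y = f(x)
    z = y[:1]  # doesn't crash
    self.assertAllClose(z, 2 * x[:1], check_dtypes=False)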
# TODO(mattjj): this fails with multiple devices (unless we add a jit)
# because we assume eager ops (like scan here) can't require more than 1
# replica.
@unittest.skip("need eager multi-replica support")
def testPostProcessMap(self):
# test came from https://github.com/google/jax/issues/1369
nrep = jax.device_count()
def pmvm(a, b):
a = a.reshape((nrep, -1, a.shape[1]))
func = self.pmap(lambda z: jnp.dot(z, b))
return func(a).reshape(b.shape)
n = nrep * 2
rng = self.rng()
a = rng.randn(n, n)
b = rng.randn(n)
iters = jnp.arange(5)
def body(carry, i):
return pmvm(a, carry), i
ans, _ = lax.scan(body, b, iters)
expected = np.linalg.matrix_power(a, 5).dot(b)
self.assertAllClose(ans, expected, check_dtypes=False)
def testManyArgs(self):
@self.pmap
def f(args_list):
return sum(args_list)
vals = list(range(500))
ndevices = jax.device_count()
self.assertAllClose(f(jnp.array([vals] * ndevices)),
jnp.array([sum(vals)] * ndevices))
def testPostProcessMap2(self):
# code from https://github.com/google/jax/issues/2787
def vv(x, y):
"""Vector-vector multiply"""
return jnp.dot(x, y)
def distributed_matrix_vector(x, y):
"""Matrix vector multiply. First batch it and then row by row"""
fv = lambda z: lax.map(lambda j: vv(j, y), z)
res = self.pmap(fv)(x.reshape((jax.device_count(), -1) + tuple(x.shape[1:])))
res = res.reshape(res.shape[0] * res.shape[1], *res.shape[2:])
return res
key = random.PRNGKey(1)
x = random.normal(key, (80, 50))
batched_mvm = vmap(lambda b: distributed_matrix_vector(x, b), in_axes=0)
y = random.normal(key, (10, 50, 1))
result = batched_mvm(y)
expected = jnp.einsum('ij,njk->nik', x, y)
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)
def testAxisIndexRemat(self):
# https://github.com/google/jax/issues/2716
n = len(jax.devices())
def f(key):
key = random.fold_in(key, jax.lax.axis_index('i'))
return random.bernoulli(key, p=0.5)
keys = random.split(random.PRNGKey(0), n)
self.pmap(jax.remat(f), axis_name='i')(keys)
def testPmapMapVmapCombinations(self):
# https://github.com/google/jax/issues/2822
def vv(x, y):
"""Vector-vector multiply"""
return jnp.dot(x, y)
def matrix_vector(x, y, parallel=True):
"""Matrix vector multiply. First batch it and then row by row"""
fv = lambda z: lax.map(lambda j: vv(j, y), z)
if parallel:
# split leading axis in two
new_x = x.reshape((jax.device_count(), -1, *x.shape[1:]))
# apply map
new_res = self.pmap(fv)(new_x)
# reshape back out
res = new_res.reshape(x.shape[0], *new_res.shape[2:])
else:
res = fv(x)
return res
x = random.normal(random.PRNGKey(1), (80, 5))
y = random.normal(random.PRNGKey(1), (10, 5))
result1 = vmap(lambda b: matrix_vector(x, b, True))(y) # vmap + pmap
result2 = lax.map(lambda b: matrix_vector(x, b, False), y) # map + map
with ignore_jit_of_pmap_warning():
result3 = lax.map(lambda b: matrix_vector(x, b, True), y) # map + pmap
result4 = jnp.stack([matrix_vector(x, b, False) for b in y]) # none + map
self.assertAllClose(result1, result2, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result3, check_dtypes=False, atol=1e-3, rtol=1e-3)
self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3)
def testPmapAxisNameError(self):
# https://github.com/google/jax/issues/3120
a = np.arange(4)[np.newaxis, :]
def test(x):
return jax.lax.psum(x, axis_name='batch')
with self.assertRaisesRegex(NameError, "unbound axis name: batch"):
self.pmap(test)(a)
def testPsumOnBooleanDtype(self):
# https://github.com/google/jax/issues/3123
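# psum over booleans should promote to an integer count of True entries,
# and pmean should then yield the fraction of True entries as a float.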
n = jax.device_count()
if n > 1:
x = jnp.array([True, False])
out = self.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)
self.assertEqual(list(out), [1, 1])
out = self.pmap(lambda x: jax.lax.pmean(x, 'i'), 'i')(x)
self.assertEqual(list(out), [1/2, 1/2])
else:
x = jnp.array([True])
out = self.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)
self.assertEqual(list(out), [1])
out = self.pmap(lambda x: jax.lax.pmean(x, 'i'), 'i')(x)
self.assertEqual(list(out), [1])
def testPsumWithNoAxisDoesntLeakFunctions(self):
x = jnp.ones((1, 1024), dtype=np.float32)
f = lambda _: x
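# Hold only a weak reference to f so we can observe below whether anything
# (in particular the pmap cache) is still keeping it alive.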
w = weakref.ref(f)
g = self.pmap(f)
g(np.ones((1,), dtype=np.float32)).block_until_ready()
del f, g
gc.collect()
# 'f' should not be alive at this point; in particular the pmap cache must
# not keep it alive.
self.assertIsNone(w())
def testJitOfPmapWarningMessage(self):
device_count = jax.device_count()
if device_count == 1:
raise SkipTest("test requires at least two devices")
def foo(x): return x
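# jit-of-pmap can lead to unnecessary data movement, so JAX emits a warning;
# this test checks that the warning names the jitted function.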
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
jit(self.pmap(foo))(jnp.arange(device_count))
self.assertGreaterEqual(len(w), 1)
self.assertIn("The jitted function foo includes a pmap",
str(w[-1].message))
def testPsumZeroCotangents(self):
# https://github.com/google/jax/issues/3651
def loss(params, meta_params):
(net, mpo) = params
return meta_params * mpo * net
def inner(meta_params, params):
grads = jax.grad(loss)(params, meta_params)
grads = lax.psum(grads, axis_name="i")
net_grads, mpo_grads = grads
net = params[0] + net_grads
mpo = params[1]
return mpo * net
def outer(params):
meta_params = jnp.array(4.0)
return jax.grad(inner)(meta_params, params)
params = (jnp.array([2.0]), jnp.array([3.0]))
self.pmap(outer, axis_name='i')(params) # doesn't crash
f = self.pmap(outer, axis_name='i')
jtu.check_grads(f, (params,), 2, ["fwd", "rev"], 1e-3, 1e-3)
@ignore_jit_of_pmap_warning()
def test_issue_1062(self):
# code from https://github.com/google/jax/issues/1062 @shoyer
# this tests, among other things, whether ShardedDeviceTuple constants work
device_count = jax.device_count()
@jit
def multi_step(state, count):
return lax.fori_loop(0, count, lambda i, s: s, state)
@jit
def multi_step_pmap(state, count=2):
@partial(self.pmap, axis_name='x')
def pmapped_multi_step(state):
return multi_step(state, count)
return pmapped_multi_step(state)
u = np.ones((device_count, 100))
multi_step_pmap(u) # doesn't crash
@jtu.skip_on_devices("cpu")
def test_replicate_backend(self):
# TODO(skye): fix backend caching so we always have multiple CPUs available
if jax.device_count("cpu") < 4:
self.skipTest("test requires 4 CPU devices")
# https://github.com/google/jax/issues/4223
def fn(indices):
return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32)
mapped_fn = self.pmap(fn, axis_name='i', backend='cpu')
mapped_fn = self.pmap(mapped_fn, axis_name='j', backend='cpu')
indices = np.array([[[2], [1]], [[0], [0]]])
mapped_fn(indices) # doesn't crash
@ignore_xmap_warning()
def testPdotBasic(self):
num_devices = jax.device_count()
def f(x, y):
return lax.pdot(x, y, 'i')
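# pdot contracts its operands along the named mapped axis, so with
# out_axes=None the (replicated) result should equal jnp.dot(x.T, y).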
x = jnp.arange(num_devices * 3).reshape(num_devices, 3)
y = jnp.arange(num_devices * 5).reshape(num_devices, 5)
z = self.pmap(f, axis_name='i', out_axes=None)(x, y)
self.assertAllClose(z, jnp.dot(x.T, y))
@parameterized.named_parameters(
{"testcase_name": "_shape={}_axis={}_collective={}".format(
jtu.format_shape_dtype_string(shape, dtype),
axis, collective.__name__.replace(" ", "")),
"shape": shape, "dtype": dtype, "axis": axis,
"collective": collective, "bulk_op": bulk_op}
for collective, bulk_op in [
(parallel.pargmax, jnp.argmax),
(parallel.pargmin, jnp.argmin)
]
for dtype in [np.float32, np.int32]
for shape in [(4,), (2, 2), (2, 4), (4, 2)]
for axis in range(len(shape))
)
def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op):
if jax.device_count() < shape[axis]:
raise SkipTest(f"test requires at least {shape[axis]} devices")
if (jtu.device_under_test() == 'cpu' and
np.issubdtype(dtype, np.floating) and
len(shape) > 1):
raise SkipTest("skipped on cpu due to strange failures") # TODO(mattjj)
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
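# The arg-collective over the mapped axis should agree with the bulk
# argmax/argmin taken along that same axis of the unsharded array.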
ans = self.pmap(lambda x: collective(x, 'i'), in_axes=axis, out_axes=None,
axis_name='i')(x)
expected = bulk_op(x, axis=axis)
self.assertAllClose(ans, expected, check_dtypes=False)
@parameterized.named_parameters(
{"testcase_name": "_dtype={}".format(
jtu.format_shape_dtype_string((), dtype)),
"dtype": dtype}
for dtype in [np.float32, np.int32]
)
def testPmapDtype(self, dtype):
# Regression test for https://github.com/google/jax/issues/6022
@partial(self.pmap, axis_name='i')
def func(_):
return jax.lax.psum(dtype(0), axis_name='i')
unused_arg = jnp.arange(jax.device_count())
out_dtype = func(unused_arg).dtype
self.assertEqual(out_dtype, dtype)
def test_num_replicas_with_switch(self):
# https://github.com/google/jax/issues/7411
def identity(x):
return x
def cond_of_pmap(x):
y = lax.cond(True, jax.pmap(identity), jax.pmap(identity), x)
return y
with ignore_jit_of_pmap_warning():
cond_of_pmap(jnp.zeros((jax.device_count(), 2)))
def test_static_argnum_on_method(self):
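# static_broadcasted_argnums=(0,) marks `self` as static, so pmap can be
# applied directly to an instance method.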
class A:
@partial(self.pmap, static_broadcasted_argnums=(0,))
def my_func_pmap(self, x):
return x + 2
A().my_func_pmap(jnp.asarray([3] * jax.device_count()))
def test_pmap_error_on_non_hashable_static_argument(self):
f = lambda x, y: x + 3
pmapped_f = self.pmap(f, static_broadcasted_argnums=(1,))
inputs = np.asarray([1] * jax.device_count())
with self.assertRaisesRegex(
ValueError, "Non-hashable static arguments are not supported.*"):
pmapped_f(inputs, np.asarray(1))
@parameterized.named_parameters(
{"testcase_name": f"_axis_size={axis_size}", "axis_size": axis_size}
for axis_size in [1, 2])
def test_grad_of_pmap_compilation_caching(self, axis_size):
if len(jax.local_devices()) < axis_size:
raise SkipTest("too few devices for test")
@jax.pmap
def f(x):
return jnp.sin(x)
x = jnp.ones(axis_size)
f(x) # warm-up any dispatching compilations
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_, f_bwd = jax.vjp(f, x)
_ = f_bwd(x)
self.assertEqual(count[0], 2) # one for fwd, one for bwd
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
_ = jax.vjp(f, x)
_ = f_bwd(x)
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
def testSizeOverflow(self):
x = jnp.arange(1)
x = self.pmap(lambda _: jnp.ones([8, 267736, 1024], dtype=jnp.int8))(x)
self.assertEqual(x.size, 8 * 267736 * 1024)
self.assertEqual(type(x.size), int)
class CppPmapTest(PythonPmapTest):
@property
def pmap(self):
return src_api._cpp_pmap
def pmap_fast_path_is_enabled(self):
num_devices = jax.device_count()
f = jax.pmap(lambda x: x+1)
size = f._cache_size()
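# Each distinct call signature should add exactly one entry to the C++
# fast-path cache.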
f(np.zeros([num_devices], dtype=np.float32))
self.assertEqual(f._cache_size(), size+1)
class VmapOfPmapTest(jtu.JaxTestCase):
# TODO(apaszke)
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
"testcase_name": f"{shapes}_{vmap_in_axes}_{vmap_out_axes}_{pmap_in_axes}_{pmap_out_axes}",
"shapes": shapes,
"vmap_in_axes": vmap_in_axes, "vmap_out_axes": vmap_out_axes,
"pmap_in_axes": pmap_in_axes, "pmap_out_axes": pmap_out_axes
} for arg_shapes in s(compatible_shapes)
for num_args in s(range(1, 4))
for shapes in s(list(it.combinations_with_replacement(arg_shapes, num_args)))
for vmap_in_axes in s(all_bdims(*shapes, pmap=False))
for pmap_in_axes in s(all_bdims(*shapes, pmap=True))
for vmap_out_axes in s(out_bdims(shapes[0], False))
for pmap_out_axes in s(out_bdims(shapes[0], True))
)))
def testVmapOfPmap(self, shapes, vmap_in_axes, pmap_in_axes, vmap_out_axes, pmap_out_axes):
vmapped_size = 3
pmapped_size = jax.device_count()
rng = jtu.rand_default(self.rng())
def fun(*args):
return sum(args)
final_shapes = map(partial(add_bdim, vmapped_size), vmap_in_axes,
map(partial(add_bdim, pmapped_size), pmap_in_axes, shapes))
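# args_slice(vi, pi) extracts the single problem instance at vmap index vi
# and pmap index pi by slicing away the corresponding batched axes.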
def args_slice(vi, pi):
return args_slicer(args_slicer(args, vmap_in_axes)(vi), pmap_in_axes)(pi)
args = [rng(shape, jnp.float32) for shape in final_shapes]
ans = vmap(pmap(fun, in_axes=pmap_in_axes, out_axes=pmap_out_axes),
in_axes=vmap_in_axes,
out_axes=vmap_out_axes)(*args)
expected = np.stack(
[np.stack([fun(*args_slice(vi, pi)) for pi in range(pmapped_size)], axis=pmap_out_axes)
for vi in range(vmapped_size)],
axis=vmap_out_axes)
self.assertAllClose(ans, expected)
class VmapPmapCollectivesTest(jtu.JaxTestCase):
@parameterized.named_parameters(
{"testcase_name": "_collective={}".format(collective.__name__).replace(" ", ""),
"collective": collective}
for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin])
def testCollectivesWithVmap(self, collective):
def f(map1, map2):
@partial(map1, axis_name='i')
@partial(map2, axis_name='j')
def f(x, y):
return x + collective(x.dot(y), ('i', 'j'))
return f
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
y = f(jax.pmap, jax.pmap)(x, x)
self.assertAllClose(f(jax.vmap, jax.vmap)(x, x), y)
self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y)
self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y)
@parameterized.named_parameters(
{"testcase_name": "_collective={}".format(collective.__name__).replace(" ", ""),
"collective": collective}
for collective in [lax.psum, lax.pmean, lax.pmax, lax.pmin])
def testCollectivesWithVmap2(self, collective):
def f(map1, map2):
@partial(map1, axis_name='i')
@partial(map2, axis_name='j')
def f(x, y):
return x + collective(x.dot(y), ('i', 'j'))
return f
if jax.device_count() < 8:
raise SkipTest("test requires at least eight devices")
x = jnp.arange(4*2*64*64).reshape(4, 2, 64, 64)
y = f(jax.pmap, jax.pmap)(x, x)
self.assertAllClose(f(jax.vmap, jax.vmap)(x, x), y)
self.assertAllClose(f(jax.pmap, jax.vmap)(x, x), y)
self.assertAllClose(f(jax.vmap, jax.pmap)(x, x), y)
def testPPermuteWithVmap(self):
perm = [(0, 1), (1, 0)]
def f(map2):
@partial(jax.pmap, axis_name='i')
@partial(map2)
def f(x, y):
return x + jax.lax.ppermute(x.dot(y), 'i', perm)
return f
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
self.assertAllClose(f(jax.pmap)(x, x), f(jax.vmap)(x, x))
def testPPermuteAgreesWithVmap(self):
if jax.device_count() < 3:
raise SkipTest("test requires at least three devices")
def f(x):
return lax.ppermute(x, 'i', [[1, 0], [2, 1], [0, 2]])
xs = jnp.arange(3) * 10
ys = jax.pmap(f, axis_name='i')(xs)
zs = jax.vmap(f, axis_name='i')(xs)
self.assertAllClose(ys, zs, check_dtypes=True)
@parameterized.named_parameters(
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_vmap={vmap_axis}",
"split_axis": split_axis, "concat_axis": concat_axis, "vmap_axis": vmap_axis}
for split_axis, concat_axis, vmap_axis in it.product(range(3), range(3), range(4)))
@ignore_slow_all_to_all_warning()
def testAllToAllInVmap(self, split_axis, concat_axis, vmap_axis):
def f(x):
return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
def adj(axis, hidden_axes):
for hax in sorted(hidden_axes):
if hax <= axis:
axis += 1
return axis
def reference(x, split_axis, concat_axis, vmap_axis):
pmap_axis = 0
vmap_axis = adj(vmap_axis, [pmap_axis])
ref = x
# Step 1.
# Adjust the split axis to the real tensor layout and move it to
# position 1. Since pmap_axis is always 0 we don't have to adjust it,
# but we do have to adjust vmap_axis.
split_axis = adj(split_axis, [pmap_axis, vmap_axis])
ref = jnp.moveaxis(ref, split_axis, pmap_axis + 1)
vmap_axis = vmap_axis + (0 if split_axis < vmap_axis else 1)
split_axis = pmap_axis + 1  # split_axis == 1
# Step 2.
# Now, we move pmap_axis to the position indicated by concat_axis.
concat_axis = adj(concat_axis, [pmap_axis, split_axis, vmap_axis]) - 1
ref = jnp.moveaxis(ref, pmap_axis, concat_axis)
pmap_axis = 0
vmap_axis = vmap_axis - (1 if concat_axis >= vmap_axis else 0)
del split_axis, concat_axis
# Step 3. vmap_axis always ends in position 1, since out_axes=0.
ref = jnp.moveaxis(ref, vmap_axis, 1)
return ref
def verify_ref():
# Both the reference and the real implementation of all_to_all batching involve
# some pretty complicated axis arithmetic, so it would be good to verify that it's
# not the case that the test passes because they're both incorrect. Fortunately, it
# is quite easy to write out the shape function for this code, and we know
# that it should be equivalent to a bunch of transposes, so the code below verifies
# that the reference puts the right dimensions in the right places. Note that we
# can't do the same comparison on f, since all_to_all wouldn't allow us to swap axes of
# different sizes.
start_shape = [2, 3, 4, 5, 6]
instance_shape = start_shape.copy()
pmap_dim_id = instance_shape.pop(0)
vmap_dim_id = instance_shape.pop(vmap_axis)
split_axis_id = instance_shape.pop(split_axis)
instance_shape.insert(concat_axis, pmap_dim_id)
expected_shape = (split_axis_id, vmap_dim_id, *instance_shape)
x = np.empty(start_shape)
self.assertEqual(reference(x, split_axis, concat_axis, vmap_axis).shape,
expected_shape)
verify_ref()
shape = (jax.device_count(),) * 5
x = jnp.arange(np.prod(shape)).reshape(shape)
self.assertAllClose(pmap(vmap(f, in_axes=vmap_axis), axis_name='i')(x),
reference(x, split_axis, concat_axis, vmap_axis))
@parameterized.named_parameters(
{"testcase_name": f"_split={split_axis}_concat={concat_axis}",
"split_axis": split_axis, "concat_axis": concat_axis}
for split_axis, concat_axis in it.product(range(3), range(3)))
@ignore_slow_all_to_all_warning()
def testAllToAllVsVmap(self, split_axis, concat_axis):
def f(x):
return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
shape = (jax.device_count(),) * 4
x = jnp.arange(np.prod(shape)).reshape(shape)
self.assertAllClose(pmap(f, axis_name='i')(x),
vmap(f, axis_name='i')(x))
@parameterized.named_parameters(
{"testcase_name": f"_split={split_axis}_concat={concat_axis}_axes={''.join(axes)}",
"axes": axes, "split_axis": split_axis, "concat_axis": concat_axis}
for axes, split_axis, concat_axis
in it.product([('i', 'j'), ('j', 'i')], range(3), range(3)))
@ignore_slow_all_to_all_warning()
@unittest.skip("multi-axis all_to_all broken after #4835") # TODO(mattjj,apaszke)
def testAllToAllMultipleAxesVsVmap(self, axes, split_axis, concat_axis):
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
def f(x):
return lax.all_to_all(x, axes, split_axis=split_axis, concat_axis=concat_axis)
shape = (2, 2, 4, 4, 4)
x = jnp.arange(np.prod(shape)).reshape(shape)
self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x),
vmap(vmap(f, axis_name='j'), axis_name='i')(x))
def testAllGatherWithVmap(self):
def f(map2):
@partial(jax.pmap, axis_name='i')
@partial(map2)
def f(x):
return jax.lax.all_gather(x, 'i')
return f
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
x = jnp.ones((2, 2, 64, 64))
self.assertAllClose(f(jax.pmap)(x), f(jax.vmap)(x))
class PmapWithDevicesTest(jtu.JaxTestCase):
def testAllDevices(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
devices=jax.devices())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected)
def testOneDevice(self):
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
d0 = jax.devices()[0]
d1 = jax.devices()[1]
f = lambda x: jnp.dot(x, x.T)
f0 = pmap(f, devices=[d0])
f1 = pmap(f, devices=[d1])
x = self.rng().rand(1, 1000, 1000)
r0 = f0(x)
r1 = f1(x)
expected = np.expand_dims(np.dot(x.squeeze(), x.squeeze().T), 0)
self.assertAllClose(r0, expected, atol=1e-6, rtol=1e-3)
self.assertAllClose(r1, expected, atol=1e-6, rtol=1e-3)
def testNoDevicesError(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaisesRegex(
ValueError, "'devices' argument to pmap must be non-empty, or None."):
f(x)
def testBadAxisSizeError(self):
if jax.device_count() == 1:
raise SkipTest("this test requires multiple devices")
f = pmap(lambda x: lax.psum(x, 'i'), axis_name='i',
devices=jax.devices())
with self.assertRaisesRegex(
ValueError, r"Leading axis size of input to pmapped function must "
r"equal the number of local devices passed to pmap. Got axis_size=1, "
r"num_local_devices=\d."):
f(jnp.ones(1))
with self.assertRaisesRegex(
ValueError, r"Leading axis size of input to pmapped function must "
r"equal the number of local devices passed to pmap. Got axis_size=\d, "
r"num_local_devices=\d."):
f(jnp.ones(jax.device_count() + 1))
def testBadAxisSizeErrorNested(self):
f = pmap(pmap(lambda x: lax.psum(x, ('i', 'j')),
axis_name='j'),
axis_name='i',
devices=[jax.local_devices()[0]])
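# The outer pmap is given a single device, but the inner pmap over an axis
# of size 4 requires 4 local devices in total.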
with self.assertRaisesRegex(
ValueError,
r"pmapped function requires 4 local devices to run due to nested "
r"pmapped or other parallel functions, but only 1 are available."):
f(jnp.ones((1, 4)))
def testNestedPmaps(self):
if jax.device_count() % 2 != 0:
raise SkipTest("test requires an even number of devices")
# Devices specified in outer pmap are OK
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@partial(pmap, axis_name='j')
def bar(y):
return lax.psum(y, 'j')
return bar(x)
x = jnp.ones((jax.device_count() // 2, 2))
ans = foo(x)
expected = x * 2
self.assertAllClose(ans, expected)
def testNestedPmapsBools(self):
if jax.device_count() % 2 != 0:
raise SkipTest("test requires an even number of devices")
# Devices specified in outer pmap are OK
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@partial(pmap, axis_name='j')
def bar(y):
return jnp.logical_not(y)
return bar(x)
x = jnp.ones((jax.device_count() // 2, 2), jnp.bool_)
ans = foo(x)
expected = jnp.zeros((jax.device_count() // 2, 2), jnp.bool_)
self.assertAllClose(ans, expected)
def testNestedPmapsError(self):
# Devices specified in inner pmap not OK
@partial(pmap, axis_name='i')
def foo(x):
@partial(pmap, axis_name='j', devices=jax.devices())
def bar(y):
return lax.psum(y, 'j')
return bar(x)
with self.assertRaisesRegex(
ValueError,
"Nested pmap with explicit devices argument."):
foo(jnp.ones((jax.device_count(), 1)))
def testJitInPmap(self):
@partial(pmap, axis_name='i', devices=jax.devices())
def foo(x):
@jit
def bar(y):
return y + 1
return lax.psum(bar(x), 'i')
ndevices = jax.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices * 2
self.assertAllClose(ans, expected)
@ignore_jit_of_pmap_warning()
def testPmapInJit(self):
@jit
def foo(x):
@partial(pmap, axis_name='i', devices=jax.devices())
def bar(y):
return lax.psum(y, 'i')
return bar(x)
ndevices = jax.device_count()
ans = foo(jnp.ones((ndevices, 1)))
expected = np.ones((ndevices, 1), dtype=jnp.float_) * ndevices
self.assertAllClose(ans, expected)
def testGradBasic(self):
@partial(pmap, axis_name='i', devices=jax.devices())
def f(x):
return jnp.sin(x)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
expected = grad(lambda x: jnp.sum(f(x)))(x)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPmapStaticArgnums(self):
@partial(pmap, axis_name='i', static_broadcasted_argnums=1)
def f(x, y):
return jnp.sin(x + y())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = lambda: 3.
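# A static argument is treated as a hashable compile-time constant, so a
# plain Python callable works here.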
ans = f(x, y)
expected = np.sin(x + 3.)
self.assertAllClose(ans, expected, check_dtypes=False)
def testPmapInAxesBasic(self):
@partial(pmap, in_axes=(1, 2))
def f(x, y):
return jnp.sin(x + y)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
yshape = (2, 4, jax.device_count())
y = np.arange(prod(yshape)).reshape(yshape)
self.assertAllClose(f(x, y),
jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1))))
def testPmapInAxesGrad(self):
def f(x, y, z):
return jnp.sin(x + y + z)
fp = pmap(f, in_axes=(1, 2, None))
fv = vmap(f, in_axes=(1, 2, None))
xshape = (5, jax.device_count(), 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, 7, jax.device_count())
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
zshape = (5, 7)
z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape)
dx, dy, dz = jax.grad(lambda args: fp(*args).sum())((x, y, z))
assert dx.shape == xshape
assert dy.shape == yshape
assert dz.shape == zshape
self.assertAllClose(jax.grad(lambda args: fp(*args).sum())((x, y, z)),
jax.grad(lambda args: fv(*args).sum())((x, y, z)))
def testPmapOutAxesBasic(self):
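    # out_axes=(2, None): stack the first output along axis 2; None means the
    # output must be identical across devices (y is unmapped here) and is
    # returned without a device axis.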
@partial(pmap, in_axes=(1, None), out_axes=(2, None))
def f(x, y):
return jnp.sin(x + y), y * 2
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
yshape = (2, 4)
y = np.arange(prod(yshape)).reshape(yshape)
self.assertAllClose(f(x, y),
(jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2))
def testPmapDictOutAxes(self):
# see issue #6410
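    # out_axes may also be a pytree (prefix) matching the output structure.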
@partial(pmap, out_axes={'a': 0})
def f(x):
return {'a': x}
device_count = jax.device_count()
x = jnp.arange(device_count)
tree_util.tree_map(self.assertAllClose, f(x), {'a': x})
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_{in_axes}_{out_axes}",
"in_axes": in_axes, "out_axes": out_axes}
for in_axes in all_bdims((3, 4), (3, 1), (1, 4), pmap=True)
for out_axes in out_bdims((3, 4), True)
))
def testPmapAllAxesGrad(self, in_axes, out_axes):
def f(x, y, z):
return jnp.sin(x + y) * z
pmapped_size = jax.device_count()
mapped_shapes = [(3, 4), (3, 1), (1, 4)]
arg_shapes = map(partial(add_bdim, pmapped_size), in_axes, mapped_shapes)
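    # add_bdim (a helper defined earlier in this file) inserts the pmapped
    # axis of size pmapped_size at each mapped in_axes position.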
rng = jtu.rand_default(self.rng())
args = [rng(shape, jnp.float64) for shape in arg_shapes]
jtu.check_grads(pmap(f, in_axes=in_axes, out_axes=out_axes), args,
order=2, atol=2e-2, rtol=2e-2, eps=1e-3)
def testPmapPostProcess(self):
def mk_case(map_fun):
def f(x, y):
# NOTE: Map doesn't have any arguments we differentiate wrt
@partial(map_fun, in_axes=1, out_axes=2)
def h(y):
return jnp.sin(x + y)
return h(y).sum()
return f
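    # Only the closed-over x is differentiated, so the gradient must flow out
    # of the map through the post_process path; vmap serves as the reference.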
xshape = (5, 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, jax.device_count(), 7)
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
self.assertAllClose(jax.grad(mk_case(pmap))(x, y),
jax.grad(mk_case(vmap))(x, y))
class ShardedDeviceArrayTest(jtu.JaxTestCase):
def testThreadsafeIndexing(self):
# NOTE(skye): I picked these values to be big enough to cause interesting
# execution overlap, but small enough to not use too much memory. YMMV.
shape = (8, 8000, 1000)
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")
x = jnp.arange(prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
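    # The identity pmap leaves the result sharded across devices, so the
    # concurrent indexing below reads from the per-device buffers.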
num_threads = 10
futures = []
expected = []
with ThreadPoolExecutor(max_workers=num_threads) as executor:
for i in range(num_threads):
idx = i % shape[0]
# Mix together different kinds of indices
if i % 2 == 0:
idx = slice(idx, idx + 1)
# Use the "kwarg trick" to work around late-binding closures. See
# https://docs.python-guide.org/writing/gotchas/#late-binding-closures.
futures.append(executor.submit(
lambda idx=idx: [sharded_x[idx] for _ in range(10)][0]))
expected.append(x[idx])
actual = [f.result() for f in futures]
self.assertAllClose(actual, expected, check_dtypes=False)
def testNoCopyIndexing1D(self):
shape = (8, 4)
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")
x = jnp.arange(prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
self.assertIsNone(sharded_x._npy_value)
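    # Integer indexing along the sharded axis should be served from the
    # per-device buffers without materializing the array on the host
    # (_npy_value is the cached host copy and must stay unset).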
for i in range(8):
self.assertIsInstance(sharded_x[i], device_array.DeviceArray)
self.assertIsNone(sharded_x._npy_value)
def test_device_put_sharded_array(self):
devices = jax.local_devices()
n_devices = len(devices)
x = [np.arange(i, i + 4) for i in range(n_devices)]
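    # device_put_sharded places x[i] on devices[i] and stacks the shards
    # along a new leading axis.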
y = jax.device_put_sharded(x, devices)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertEqual(len(y.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
self.assertArraysEqual(y, jnp.stack(x))
def test_device_put_sharded_pytree(self):
devices = jax.local_devices()
n_devices = len(devices)
x = [(i, np.arange(i, i + 4)) for i in range(n_devices)]
y1, y2 = jax.device_put_sharded(x, devices)
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
self.assertArraysEqual(y1, jnp.array([a for a, _ in x]))
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
self.assertArraysEqual(y2, jnp.vstack([b for _, b in x]))
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
def test_device_put_replicated_array(self):
devices = jax.local_devices()
x = np.arange(1, 5)
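    # device_put_replicated copies the same array onto every device, so the
    # result compares equal to np.stack([x] * len(devices)).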
y = jax.device_put_replicated(x, devices)
self.assertIsInstance(y, pxla.ShardedDeviceArray)
self.assertEqual(len(y.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y.device_buffers, devices)))
self.assertArraysEqual(y, np.stack([x for _ in devices]))
def test_device_put_replicated_pytree(self):
devices = jax.local_devices()
xs = {'a': np.arange(1, 5), 'b': np.arange(3)}
ys = jax.device_put_replicated(xs, devices)
self.assertIsInstance(ys, dict)
y1, y2 = ys['a'], ys['b']
self.assertIsInstance(y1, pxla.ShardedDeviceArray)
self.assertEqual(len(y1.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y1.device_buffers, devices)))
self.assertArraysEqual(y1, np.stack([xs['a'] for _ in devices]))
self.assertIsInstance(y2, pxla.ShardedDeviceArray)
self.assertEqual(len(y2.device_buffers), len(devices))
self.assertTrue(all(b.device() == d for b, d in zip(y2.device_buffers, devices)))
self.assertArraysEqual(y2, np.stack([xs['b'] for _ in devices]))
def test_repr(self):
x = jax.device_put_replicated(1, jax.devices())
self.assertStartsWith(repr(x), 'ShardedDeviceArray')
def test_delete_is_idempotent(self):
x = jax.device_put_replicated(1, jax.devices())
x.delete()
x.delete()
with self.assertRaisesRegex(ValueError,
'ShardedDeviceArray has been deleted.'):
_ = x[0]
class SpecToIndicesTest(jtu.JaxTestCase):
def testShardsPerAxis(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=map(pxla.ShardedAxis, (0, 1)))
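    # Chunked([2]) splits an axis into two equal chunks that remain part of
    # the on-device shape; ShardedAxis(i) assigns sharded axis i to that
    # position of the device mesh.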
self.assertEqual(pxla.spec_to_indices(shape, spec),
((slice(0,2), slice(0,4)),
(slice(0,2), slice(4,8)),
(slice(2,4), slice(0,4)),
(slice(2,4), slice(4,8))))
def testShardedAxisPermutation(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=map(pxla.ShardedAxis, (1, 0)))
self.assertEqual(pxla.spec_to_indices(shape, spec),
((slice(0,2), slice(0,4)),
(slice(2,4), slice(0,4)),
(slice(0,2), slice(4,8)),
(slice(2,4), slice(4,8))))
def testShardedAxisPermutationAndReplication(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=map(pxla.Chunked, ([2], [2])),
mesh_mapping=(pxla.Replicated(2),
pxla.ShardedAxis(1),
pxla.ShardedAxis(0)))
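    # Replicated(2) inserts a mesh dimension of size 2 that repeats every
    # shard, hence the doubled tuple of indices below.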
self.assertEqual(pxla.spec_to_indices(shape, spec),
((slice(0,2), slice(0,4)),
(slice(2,4), slice(0,4)),
(slice(0,2), slice(4,8)),
(slice(2,4), slice(4,8))) * 2)
def testUnshardedAxis(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0),))
self.assertEqual(pxla.spec_to_indices(shape, spec),
((slice(0,2), slice(None)),
(slice(2,4), slice(None))))
def testNoSharding(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
mesh_mapping=())
self.assertEqual(pxla.spec_to_indices(shape, spec),
((slice(None), slice(None)),))
def testUnmaterializedAxis(self):
shape = (4, 8)
spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()),
mesh_mapping=(pxla.ShardedAxis(0),))
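    # Unstacked(4) splits the axis into four separate shards, removing it
    # from the on-device shape, so each index is an integer rather than a
    # slice.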
self.assertEqual(pxla.spec_to_indices(shape, spec),
((0, slice(None)),
(1, slice(None)),
(2, slice(None)),
(3, slice(None))))
shape = (2, 2)
spec = pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)),
mesh_mapping=(pxla.ShardedAxis(0),))
    self.assertEqual(pxla.spec_to_indices(shape, spec),
                     ((slice(None), 0),
                      (slice(None), 1)))
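
  # The replication tests below pin down how a Replicated mesh dimension
  # interacts with enumeration order: spec_to_indices walks the device mesh
  # in row-major order, so replication in the minor-most mesh position repeats
  # each shard index consecutively, while replication in the major-most
  # position repeats the entire index sequence.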
  def testReplicationAfterUnsharded(self):
    shape = (2, 8)
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                             mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))
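    # 2 unstacked shards x 3-way replication = 6 mesh positions; replication
    # is minor-most, so each row's index appears 3 times in a row.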
    self.assertEqual(pxla.spec_to_indices(shape, spec),
                     tuple([(0, slice(None))] * 3 + [(1, slice(None))] * 3))

  def testReplicationPosition2(self):
    shape = (2, 8)
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
                             mesh_mapping=(pxla.ShardedAxis(0), pxla.ShardedAxis(1),
                                           pxla.Replicated(3)))
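    # Mesh: 2 rows x 2 chunks x 3 replicas = 12 positions; the replicated
    # dimension is minor-most, so each (row, chunk) index repeats 3x in a row.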
    self.assertEqual(pxla.spec_to_indices(shape, spec),
                     ((0, slice(0, 4)), (0, slice(0, 4)), (0, slice(0, 4)),
                      (0, slice(4, 8)), (0, slice(4, 8)), (0, slice(4, 8)),
                      (1, slice(0, 4)), (1, slice(0, 4)), (1, slice(0, 4)),
                      (1, slice(4, 8)), (1, slice(4, 8)), (1, slice(4, 8))))

  def testReplicationPosition1(self):
    shape = (2, 8)
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
                             mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3),
                                           pxla.ShardedAxis(1)))
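    # Replication sits between the two sharded mesh dimensions, so each row's
    # chunk pair repeats 3x before moving on to the next row.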
    self.assertEqual(pxla.spec_to_indices(shape, spec),
                     ((0, slice(0, 4)), (0, slice(4, 8)),
                      (0, slice(0, 4)), (0, slice(4, 8)),
                      (0, slice(0, 4)), (0, slice(4, 8)),
                      (1, slice(0, 4)), (1, slice(4, 8)),
                      (1, slice(0, 4)), (1, slice(4, 8)),
                      (1, slice(0, 4)), (1, slice(4, 8))))

  def testReplicationPosition0(self):
    shape = (2, 8)
    spec = pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                             mesh_mapping=(pxla.Replicated(3), pxla.ShardedAxis(0)))
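    # Replication is major-most, so the full two-shard index sequence
    # repeats 3x.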
    self.assertEqual(pxla.spec_to_indices(shape, spec),
                     tuple([(0, slice(None)), (1, slice(None))] * 3))

  def testMultipleReplications(self):
    shape = (2, 7, 4)
    spec = pxla.ShardingSpec(
        sharding=(pxla.Unstacked(2), pxla.NoSharding(), pxla.Chunked([2])),
        mesh_mapping=(pxla.Replicated(3), pxla.Replicated(2),
                      pxla.ShardedAxis(0), pxla.Replicated(2),
                      pxla.ShardedAxis(1)))
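    # 3 * 2 * 2 * 2 * 2 = 48 mesh positions covering only 2 * 2 = 4 distinct
    # shards: the trailing Replicated(2) interleaves within the 8-entry base
    # pattern, while the leading Replicated(3) and Replicated(2) tile it 6x.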
    self.assertEqual(
        pxla.spec_to_indices(shape, spec),
        ((0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)),
         (0, slice(None), slice(0, 2)), (0, slice(None), slice(2, 4)),
         (1, slice(None), slice(0, 2)), (1, slice(None), slice(2, 4)),
         (1, slice(None), slice(0, 2)), (1, slice(None), slice(2, 4))) * 3 * 2)

  def testReplicatedScalar(self):
    shape = ()
    spec = pxla.ShardingSpec(sharding=(),
                             mesh_mapping=(pxla.Replicated(3),))
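    # A scalar has no axes to shard; each of the 3 replicas gets the trivial
    # (empty) index.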
    self.assertEqual(pxla.spec_to_indices(shape, spec),
                     ((), (), ()))
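

# Renders a ShardingSpec compactly so it can be embedded in the generated
# test-case names below.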
def _spec_str(spec):
  return (f"({spec.sharding},"
          f"{spec.mesh_mapping},)")

class ShardArgsTest(jtu.JaxTestCase):
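
  # Note: these helpers deliberately lack `self`; the parameterized decorator
  # below references them from the class body as plain argument factories.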
  def numpy_array(x):
    return x

  def device_array(x):
    return jax.device_put(x)

  # TODO(skye): add coverage for ShardedDeviceArrays

  @parameterized.named_parameters(
      {"testcase_name":
       f"_shape={shape}_spec={_spec_str(spec)}_arg={make_arg.__name__}"
       .replace(" ", ""),
       "shape": shape, "spec": spec, "make_arg": make_arg}
      for make_arg in [numpy_array, device_array]
      for shape, spec in [
          # pmap(in_axes=0)
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(4), pxla.NoSharding()),
                                     mesh_mapping=(pxla.ShardedAxis(0),))],
          # pmap(in_axes=1)
          [(2, 2), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.Unstacked(2)),
                                     mesh_mapping=(pxla.ShardedAxis(0),))],
          # unsharded
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                     mesh_mapping=())],
          # partitioned, 1 axis
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.NoSharding()),
                                     mesh_mapping=(pxla.ShardedAxis(0),))],
          # partitioned, 2 axes
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
                                     mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
          # partitioned, 2 axes, permuted
          [(4, 8), pxla.ShardingSpec(sharding=(pxla.Chunked([2]), pxla.Chunked([2])),
                                     mesh_mapping=map(pxla.ShardedAxis, (1, 0)))],
          # partitioned + sharding
          [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.Chunked([2])),
                                     mesh_mapping=map(pxla.ShardedAxis, (0, 1)))],
          # replication + sharding
          [(2, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(2), pxla.NoSharding()),
                                     mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(3)))],
          # replication, no sharding
          [(2, 8), pxla.ShardingSpec(sharding=(pxla.NoSharding(), pxla.NoSharding()),
                                     mesh_mapping=(pxla.Replicated(3),))],
          # multiple replicated axes
          [(1, 8), pxla.ShardingSpec(sharding=(pxla.Unstacked(1), pxla.Chunked([2])),
                                     mesh_mapping=(pxla.Replicated(2), pxla.ShardedAxis(0),
                                                   pxla.Replicated(2), pxla.ShardedAxis(1)))],
          # replicated scalar
          [(), pxla.ShardingSpec(sharding=(),
                                 mesh_mapping=(pxla.Replicated(2), pxla.Replicated(3)))],
      ])
  def testShardArgs(self, shape, spec, make_arg):
    indices = pxla.spec_to_indices(shape, spec)
    nshards = len(indices)
    if jax.device_count() < nshards:
      raise SkipTest
    x = np.arange(prod(shape)).reshape(shape)
    arg = make_arg(x)
    bufs = pxla.shard_args(jax.devices()[:nshards],
                           [indices], [arg])
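    # shard_args returns one list per argument; each list holds the per-device
    # buffers for that argument, in the same order as `indices`.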
    self.assertEqual(len(bufs), 1)
    self.assertEqual(len(bufs[0]), nshards)
    for buf, idx in zip(bufs[0], indices):
      self.assertAllClose(buf.to_py(), x[idx], check_dtypes=False)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())