Merge pull request #5672 from skye:pjit

PiperOrigin-RevId: 356774892
This commit is contained in:
jax authors 2021-02-10 10:57:06 -08:00
commit 14cc3f89a5
5 changed files with 422 additions and 26 deletions

View File

@ -102,6 +102,13 @@ pytype_library(
deps = [":jax"],
)
# Build target for the experimental pjit module; it depends only on the core
# ":jax" library target.
pytype_library(
name = "pjit",
srcs = ["experimental/pjit.py"],
srcs_version = "PY3",
deps = [":jax"],
)
pytype_library(
name = "jet",
srcs = ["experimental/jet.py"],

View File

@ -546,15 +546,16 @@ def make_xmap_callable(fun: lu.WrappedFun,
if used_mesh_axes:
submesh = resource_env.physical_mesh[sorted(used_mesh_axes, key=str)]
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
return pxla.mesh_tiled_callable(f,
name,
backend,
submesh,
mesh_in_axes,
mesh_out_axes,
donated_invars,
EXPERIMENTAL_SPMD_LOWERING,
*in_avals)
return pxla.mesh_callable(f,
name,
backend,
submesh,
mesh_in_axes,
lambda: mesh_out_axes,
donated_invars,
EXPERIMENTAL_SPMD_LOWERING,
*in_avals,
tile_by_mesh_axes=True)
else:
return xla._xla_callable(f, None, backend, name, donated_invars,
*((a, None) for a in in_avals))
@ -821,10 +822,10 @@ def _xmap_translation_rule_spmd(c, axis_env,
call_jaxpr, name,
in_axes, out_axes, donated_invars,
axis_sizes, axis_resources, resource_env, backend):
# TODO(apaszke): This is quite difficult to implement given the current lowering
# in mesh_tiled_callable. There, we vmap the mapped axes, but we
# have no idea which positional axes they end up being in this
# translation rule!
# TODO(apaszke): This is quite difficult to implement given the current
# lowering in mesh_callable. There, we vmap the mapped axes,
# but we have no idea which positional axes they end up being
# in this translation rule!
raise NotImplementedError

212
jax/experimental/pjit.py Normal file
View File

@ -0,0 +1,212 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from typing import (Callable, Optional, Sequence, Tuple, Union)
from warnings import warn
from . import maps
from .. import core
from .. import linear_util as lu
from ..api import _check_callable, _check_arg
from ..api_util import (argnums_partial_except, flatten_axes,
flatten_fun_nokwargs, _ensure_index_tuple,
donation_vector, rebase_donate_argnums)
from ..interpreters import ad
from ..interpreters import pxla
from ..interpreters import xla
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from ..tree_util import tree_flatten, tree_unflatten
from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
wrap_name, wraps)
# Short alias for the XLA op namespace; used below as xops.Tuple / xops.Call.
xops = xc._xla.ops
def pjit(fun: Callable,
         in_axis_resources,
         out_axis_resources,
         static_argnums: Union[int, Sequence[int]] = (),
         donate_argnums: Union[int, Sequence[int]] = ()):
  """Wraps `fun` so that each call is bound to the `pjit_call_p` primitive.

  Args:
    fun: The callable to wrap; validated with `_check_callable`.
    in_axis_resources: Axis resources for the inputs, given as a tree prefix
      of the tuple of non-static positional arguments. A top-level list is
      converted to a tuple.
    out_axis_resources: Axis resources for the outputs, given as a tree prefix
      of `fun`'s output. A top-level list is converted to a tuple.
    static_argnums: Index or indices of positional arguments that are
      partially applied to `fun` instead of being passed to the primitive.
    donate_argnums: Index or indices of positional arguments to mark as
      donated. They are rebased against `static_argnums`.

  Returns:
    A wrapper with `fun`'s metadata that accepts positional arguments only.

  Raises:
    NotImplementedError: If the wrapper is called with keyword arguments.
    ValueError: If the wrapper is called with fewer positional arguments than
      `static_argnums` / `donate_argnums` refer to.
  """
  warn("pjit is an experimental feature and probably has bugs!")
  _check_callable(fun)

  # To be a tree prefix of the positional args tuple, in_axis_resources can
  # never be a list: if it is not a leaf, it must be a tuple of trees. However,
  # in cases like these users expect tuples and lists to be treated
  # essentially interchangeably, so we canonicalize lists to tuples here
  # rather than raising an error. https://github.com/google/jax/issues/2367
  if isinstance(in_axis_resources, list):
    in_axis_resources = tuple(in_axis_resources)
  if isinstance(out_axis_resources, list):
    out_axis_resources = tuple(out_axis_resources)

  static_argnums = _ensure_index_tuple(static_argnums)
  donate_argnums = _ensure_index_tuple(donate_argnums)
  donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums)

  @wraps(fun)
  def wrapped(*args, **kwargs):
    if kwargs:
      raise NotImplementedError("pjit over kwargs not yet supported")
    if max(static_argnums + donate_argnums, default=-1) >= len(args):
      raise ValueError(f"jitted function has static_argnums={static_argnums}, "
                       f"donate_argnums={donate_argnums} but "
                       f"was called with only {len(args)} positional arguments.")

    # Putting this outside of wrapped would make resources lexically scoped
    resource_env = maps.thread_resources.env

    f = lu.wrap_init(fun)
    if static_argnums:
      f, dyn_args = argnums_partial_except(f, static_argnums, args)
    else:
      dyn_args = args

    # Flatten only the dynamic arguments: once the static ones have been
    # partially applied, `f` no longer accepts them, so the input tree, the
    # flat operands and the axis-resource prefix must all describe `dyn_args`.
    args_flat, in_tree = tree_flatten(dyn_args)
    for arg in args_flat: _check_arg(arg)
    flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
    if donate_argnums:
      donated_invars = donation_vector(donate_argnums, dyn_args, ())
    else:
      donated_invars = (False,) * len(args_flat)

    in_axis_resources_flat = tuple(flatten_axes("pjit in_axis_resources",
                                                in_tree, in_axis_resources))
    # The output tree is only known after tracing, so the flattened output
    # resources are produced lazily through a hashable thunk.
    out_axis_resources_thunk = HashableFunction(
        lambda: tuple(flatten_axes("pjit out_axis_resources", out_tree(),
                                   out_axis_resources)),
        closure=out_axis_resources)
    out = pjit_call_p.bind(
        flat_fun,
        *args_flat,
        in_axis_resources=in_axis_resources_flat,
        out_axis_resources_thunk=out_axis_resources_thunk,
        resource_env=resource_env,
        donated_invars=donated_invars,
        name=flat_fun.__name__)
    return tree_unflatten(out_tree(), out)

  return wrapped
def _pjit_call_impl(fun: lu.WrappedFun, *args, in_axis_resources,
                    out_axis_resources_thunk, resource_env, donated_invars,
                    name):
  """Impl rule for `pjit_call_p`: build the mesh callable and apply it.

  The abstract value of every argument is raised to its shaped form and
  handed to `_pjit_callable`; the callable it returns is then invoked on the
  concrete `args`.
  """
  def _shaped_aval(value):
    return core.raise_to_shaped(core.get_aval(value))

  in_avals = list(map(_shaped_aval, args))
  call = _pjit_callable(fun, in_axis_resources, out_axis_resources_thunk,
                        resource_env, donated_invars, name, *in_avals)
  return call(*args)
def _pjit_translation_rule(c, axis_env, in_nodes, name_stack, backend, name,
                           call_jaxpr, in_axis_resources, out_axis_resources_thunk,
                           resource_env, donated_invars):
  """Translation rule for `pjit_call_p` when it appears inside another computation.

  Builds a sub-computation named `pjit_<name>` whose parameters and results
  carry sharding annotations derived from the axis resources and the physical
  mesh of `resource_env`, then emits a call to it from `c`.

  Args:
    c: The enclosing XLA builder.
    axis_env, name_stack, backend: Forwarded to `xla.jaxpr_subcomp`.
    in_nodes: Operands of the call, living in `c`.
    name: Name of the wrapped function.
    call_jaxpr: The jaxpr to lower into the sub-computation.
    in_axis_resources: One axis-resource entry per element of `in_nodes`.
    out_axis_resources_thunk: Zero-argument callable returning one
      axis-resource entry per output.
    resource_env: Provides `physical_mesh`.
    donated_invars: Unused by this rule.

  Returns:
    The XLA op for the call to the sub-computation.
  """
  mesh = resource_env.physical_mesh
  subc = xc.XlaBuilder(f"pjit_{name}")

  args = []
  for i, (n, axis_resources) in enumerate(safe_zip(in_nodes, in_axis_resources)):
    # N.B. inlined calls shouldn't have shardings set directly on the inputs or
    # outputs (set_sharding_proto adds an identity operation).
    arg = xb.parameter(subc, i, c.GetShape(n))
    args.append(xb.set_sharding_proto(subc, arg,
                                      get_sharding_proto(c, n, axis_resources, mesh)))

  out_nodes = xla.jaxpr_subcomp(
      subc, call_jaxpr, backend, axis_env, (),
      extend_name_stack(name_stack, wrap_name(name, "pjit")), *args)
  out_axis_resources = out_axis_resources_thunk()
  # Each output's sharding must come from that output's own shape. The output
  # ops belong to `subc`, so the shape is queried there and not through the
  # last input node of the enclosing builder.
  out_nodes = [xb.set_sharding_proto(subc, out,
                                     get_sharding_proto(subc, out, axis_resources, mesh))
               for out, axis_resources in safe_zip(out_nodes, out_axis_resources)]

  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xops.Call(c, subc, list(in_nodes))
# Call primitive bound by pjit(); `_pjit_call_impl` is its impl rule and
# `_pjit_translation_rule` is registered as its XLA call translation.
pjit_call_p = core.CallPrimitive("pjit_call")
pjit_call_p.def_impl(_pjit_call_impl)
xla.call_translations[pjit_call_p] = _pjit_translation_rule
# Partitioning of a single array dimension: one mesh axis name or a tuple of
# mesh axis names.
# None indicates unpartitioned dimension
ArrayAxisPartitioning = Union[pxla.MeshAxisName, Tuple[pxla.MeshAxisName, ...], None]
# Partitioning of a whole array: one ArrayAxisPartitioning per dimension.
# None indicates fully replicated array value
ArrayPartitioning = Optional[Tuple[ArrayAxisPartitioning, ...]]
@lu.cache
def _pjit_callable(
    fun: lu.WrappedFun,
    in_axis_resources: Tuple[ArrayPartitioning, ...],
    out_axis_resources_thunk: Callable[[], Tuple[ArrayPartitioning, ...]],
    resource_env,
    donated_invars,
    name: str,
    *in_avals):
  """Returns the cached mesh callable for `fun` on the physical mesh.

  Input and output axis resources are converted to array mappings with
  `get_array_mapping`. The output conversion is deferred behind a thunk
  because `out_axis_resources_thunk` is itself lazy.
  """
  in_axes = list(map(get_array_mapping, in_axis_resources))

  def out_axes_thunk():
    return [get_array_mapping(axes) for axes in out_axis_resources_thunk()]

  # TODO(skye): allow for using a submesh of physical_mesh
  return pxla.mesh_callable(
      fun,
      name,
      None,  # backend name
      resource_env.physical_mesh,
      in_axes,
      out_axes_thunk,
      donated_invars,
      True,  # SPMD lowering
      *in_avals,
      tile_by_mesh_axes=False)
def with_sharding_constraint(x, axis_resources):
  """Binds `x` to `sharding_constraint_p` with the given axis resources.

  The resource environment is read from `maps.thread_resources` at call time.
  """
  return sharding_constraint_p.bind(
      x,
      axis_resources=axis_resources,
      resource_env=maps.thread_resources.env)
def _sharding_constraint_impl(x, axis_resources, resource_env):
# TODO(skye): can we also prevent this from being called in other
# non-pjit contexts? (e.g. pmap, control flow)
raise NotImplementedError(
"with_sharding_constraint() should only be called inside pjit()")
def _sharding_constraint_translation_rule(c, x_node, axis_resources, resource_env):
  """XLA translation: annotate `x_node` with the sharding for `axis_resources`.

  The sharding proto is computed against the physical mesh of `resource_env`.
  """
  sharding = get_sharding_proto(c, x_node, axis_resources,
                                resource_env.physical_mesh)
  return xb.set_sharding_proto(c, x_node, sharding)
# Primitive bound by with_sharding_constraint(). Its impl rule always raises,
# and its abstract-eval rule returns the input aval unchanged.
sharding_constraint_p = core.Primitive("sharding_constraint")
sharding_constraint_p.def_impl(_sharding_constraint_impl)
sharding_constraint_p.def_abstract_eval(lambda x, **unused: x)
# Registered as linear for AD: the transpose applies the same sharding
# constraint to the cotangent.
ad.deflinear2(sharding_constraint_p,
lambda ct, _, axis_resources, resource_env: (
with_sharding_constraint(ct, axis_resources),))
xla.translations[sharding_constraint_p] = _sharding_constraint_translation_rule
def get_array_mapping(axis_resources: ArrayPartitioning) -> pxla.ArrayMapping:
  """Converts per-dimension axis resources into a mesh-axis -> dimension map.

  `None` yields an empty mapping. Otherwise every mesh axis named for
  dimension `i` is mapped to `i`, in dimension order.
  """
  mapping = OrderedDict()
  if axis_resources is not None:
    for dim, axis_or_axes in enumerate(axis_resources):
      for mesh_axis, index in _array_mapping_entries(axis_or_axes, dim):
        mapping[mesh_axis] = index
  return mapping
def _array_mapping_entries(partitioning: ArrayAxisPartitioning, i: int):
if partitioning is None:
return
if not isinstance(partitioning, (list, tuple)):
yield (partitioning, i)
else:
for axis in partitioning:
assert isinstance(axis, str)
yield (axis, i)
def get_sharding_proto(c, xla_op, axis_resources, mesh):
  """Returns the sharding proto for `xla_op` under `axis_resources` on `mesh`.

  The op's shape and element type are read from builder `c` and wrapped in a
  `ShapedArray`, from which the mesh sharding spec is computed.
  """
  op_shape = c.GetShape(xla_op)
  aval = core.ShapedArray(op_shape.dimensions(), op_shape.element_type())
  mapping = get_array_mapping(axis_resources)
  make_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
  return make_spec(aval, mapping).sharding_proto()

View File

@ -1383,15 +1383,16 @@ def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
shape[axis] *= axis_sizes[name]
return aval.update(shape=tuple(shape))
def mesh_tiled_callable(fun: lu.WrappedFun,
transformed_name: str,
backend_name: Optional[str],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
donated_invars: Sequence[bool],
spmd_lowering,
*local_in_untiled_avals):
def mesh_callable(fun: lu.WrappedFun,
transformed_name: str,
backend_name: Optional[str],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes_thunk: Callable[[], Sequence[ArrayMapping]],
donated_invars: Sequence[bool],
spmd_lowering: bool,
*local_in_untiled_avals,
tile_by_mesh_axes: bool):
assert config.omnistaging_enabled
local_mesh = mesh.local_mesh
global_axis_sizes = mesh.shape
@ -1408,10 +1409,11 @@ def mesh_tiled_callable(fun: lu.WrappedFun,
if spmd_lowering:
# TODO: Consider handling xmap's 'vectorize' in here. We can vmap once instead of vtile twice!
for name, size in reversed(mesh.shape.items()):
fun = vtile(fun,
tuple(a.get(name, None) for a in in_axes),
tuple(a.get(name, None) for a in out_axes),
tile_size=size, axis_name=name)
if tile_by_mesh_axes:
fun = vtile(fun,
tuple(a.get(name, None) for a in in_axes),
tuple(a.get(name, None) for a in out_axes_thunk()),
tile_size=size, axis_name=name)
global_in_untiled_avals = [untile_aval_nd(global_axis_sizes, aval_in_axes, aval)
for aval, aval_in_axes in safe_zip(in_tiled_avals, in_axes)]
in_jaxpr_avals = global_in_untiled_avals
@ -1419,6 +1421,7 @@ def mesh_tiled_callable(fun: lu.WrappedFun,
in_jaxpr_avals = in_tiled_avals
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
out_axes = out_axes_thunk()
assert len(out_axes) == len(out_jaxpr_avals)
if spmd_lowering:
global_out_untiled_avals = out_jaxpr_avals

173
tests/pjit_test.py Normal file
View File

@ -0,0 +1,173 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from functools import partial
from typing import Generator, List, Tuple
from unittest import SkipTest
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax import test_util as jtu
# TODO(skye): do we still wanna call this PartitionSpec?
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.interpreters import pxla
from jax._src.util import unzip2, prod
from jax.config import config
# Presumably hooks JAX's config flags into absl flag parsing (going by the
# method name) -- confirm against jax.config.
config.parse_flags_with_absl()
# TODO(skye): move into test_util and dedup with xmap_test.py
# A mesh description: a list of (axis name, axis size) pairs, e.g. [('x', 2)].
MeshSpec = List[Tuple[str, int]]
@contextmanager
def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
  """Activates a device mesh built from `(axis name, size)` pairs.

  The first `prod(sizes)` local devices are reshaped to the requested mesh
  shape and installed with `mesh(...)` for the duration of the context. The
  tests in this file apply it as a decorator on test methods.

  Raises:
    SkipTest: If fewer local devices are available than the mesh needs.
  """
  axis_names, dims = unzip2(named_shape)
  required = prod(dims)
  available = list(jax.local_devices())
  if len(available) < required:
    raise SkipTest(f"Test requires {required} local devices")
  device_grid = np.array(available[:required]).reshape(dims)
  with mesh(device_grid, axis_names):
    yield
# TODO(skye): make the buffer donation utils part of JaxTestCase
class PJitTest(jtu.BufferDonationTestCase):
  """Checks of `pjit` and `with_sharding_constraint` on small local meshes."""

  @with_mesh([('x', 2)])
  def testBasic1D(self):
    """Both inputs use mesh axis 'x'; the output has no axis resources."""
    @partial(pjit,
             in_axis_resources=(P('x'), P('x')),
             out_axis_resources=None)
    def f(x, y):
      return x + y

    dims = (8, 8)
    x = np.arange(prod(dims), dtype=np.float32).reshape(dims)
    result = f(x, x + 1)
    want = x + (x + 1)
    self.assertAllClose(result, want, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 2)
    # The first buffer is expected to hold the whole result.
    first_buffer = result.device_buffers[0].to_py()
    self.assertAllClose(first_buffer, want, check_dtypes=False)

  @with_mesh([('x', 2), ('y', 2)])
  def testBasic2D(self):
    """Matmul on a 2x2 mesh with the output split in two along dimension 0."""
    @partial(pjit,
             in_axis_resources=(P(None, 'x', 'y'), P('y')),
             out_axis_resources=P('x'))
    def f(x, y):
      return x @ y

    x_shape = (8, 6, 4)
    y_shape = (4, 2)
    x = jnp.arange(np.prod(x_shape)).reshape(x_shape)
    y = jnp.arange(np.prod(y_shape)).reshape(y_shape)
    result = f(x, y)
    want = x @ y
    self.assertAllClose(result, want, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

    # Buffers 0 and 1 should hold the first half, buffers 2 and 3 the second.
    first_half, second_half = np.split(want, 2)
    per_buffer = (first_half, first_half, second_half, second_half)
    for idx, piece in enumerate(per_buffer):
      self.assertAllClose(result.device_buffers[idx].to_py(), piece,
                          check_dtypes=False)

  @with_mesh([('x', 2), ('y', 2)])
  def testTwoMeshAxisSharding(self):
    """One array dimension partitioned over both mesh axes at once."""
    @partial(pjit,
             in_axis_resources=P(('x', 'y'),),
             out_axis_resources=P(('x', 'y'),))
    def f(x, y):
      return x @ y

    dims = (8, 8)
    x = jnp.arange(np.prod(dims)).reshape(dims)
    result = f(x, x + 1)
    want = x @ (x + 1)
    self.assertAllClose(result, want, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 4)

    # Each of the four buffers should hold one quarter of the rows, in order.
    for idx, piece in enumerate(np.split(want, 4)):
      self.assertAllClose(result.device_buffers[idx].to_py(), piece,
                          check_dtypes=False)

  @with_mesh([('x', 2)])
  def testBufferDonation(self):
    """Only the argument named in `donate_argnums` is deleted after the call."""
    @partial(pjit,
             in_axis_resources=P('x'),
             out_axis_resources=P('x'),
             donate_argnums=0)
    def f(x, y):
      return x + y

    to_sharded = pjit(lambda x: x, in_axis_resources=P('x'),
                      out_axis_resources=P('x'))
    x = to_sharded(jnp.ones((2, 5)) * 4)
    y = to_sharded(jnp.ones((2, 5)) * 2)
    want = x + y
    result = f(x, y)
    self.assertAllClose(result, want)
    self.assertNotDeleted(y)
    self.assertDeleted(x)

  @with_mesh([('x', 2), ('y', 1)])
  def testShardingConstraint(self):
    """`with_sharding_constraint` inside pjit shows up in the lowered HLO."""
    @partial(pjit, in_axis_resources=None, out_axis_resources=None)
    def f(x):
      y = x + 1
      y = with_sharding_constraint(y, P('x', 'y'))
      return y * 2

    dims = (8, 8)
    x = np.arange(prod(dims)).reshape(dims)
    want = (x + 1) * 2
    result = f(x)
    self.assertAllClose(result, want, check_dtypes=False)
    self.assertIsInstance(result, pxla.ShardedDeviceArray)
    self.assertLen(result.device_buffers, 2)
    self.assertAllClose(result.device_buffers[0].to_py(), want,
                        check_dtypes=False)

    hlo_text = jax.xla_computation(f)(np.ones(dims)).as_hlo_text()
    # Annotation from with_sharding_constraint
    self.assertIn("sharding={devices=[2,1]0,1}", hlo_text)
    # Annotation from pjit
    self.assertIn("sharding={replicated}", hlo_text)
# TODO(skye): add more unit tests once API is more finalized
# Script entry point: run this module's tests through absltest using the
# loader provided by jax.test_util.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())