Make the gda and xmap sharding check work generally by checking the OpSharding protos.

PiperOrigin-RevId: 475560097
This commit is contained in:
Yash Katariya 2022-09-20 08:24:08 -07:00 committed by jax authors
parent 24bc153e49
commit fc2902c6ac
2 changed files with 48 additions and 11 deletions

View File

@ -20,7 +20,7 @@ from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence)
from warnings import warn
from functools import wraps, partial, partialmethod
from functools import wraps, partial, partialmethod, lru_cache
from enum import Enum
from jax import numpy as jnp
@ -1811,6 +1811,17 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
global_axis_sizes, in_axes_flat,
in_positional_semantics, args_flat):
@lru_cache()
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
if not pxla.are_op_shardings_equal(
in_sharding._to_xla_op_sharding(ndim),
xmap_sharding._to_xla_op_sharding(ndim)):
raise ValueError(
f"Got an input {arr_flavor} to xmap with different partitioning than "
"specified in xmap. The partitioning must match. "
f"Got {arr_flavor} spec: {in_sharding.spec} and "
f"xmap spec: {xmap_sharding.spec}")
mesh_in_axes = EvaluationPlan.from_axis_resources(
axis_resources, resource_env, global_axis_sizes,
in_positional_semantics).to_mesh_axes(in_axes_flat)
@ -1825,15 +1836,16 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
f"Got xmap mesh: {resource_env.physical_mesh},\n"
f"{arr_flavor} mesh: {mesh}")
array_mapping = pxla._get_array_mapping(
arg.mesh_axes if arr_flavor == 'GDA' else arg.sharding.spec)
if array_mapping != xmap_array_mapping:
raise ValueError(
f"Got an input {arr_flavor} to xmap with different partitioning than "
"specified in xmap. The partitioning must match. "
f"Got {arr_flavor} spec: {pxla.array_mapping_to_axis_resources(array_mapping)} and "
f"xmap spec: {pxla.array_mapping_to_axis_resources(xmap_array_mapping)} "
f"for {arr_flavor}: {arg}")
if arr_flavor == 'GDA':
s = pxla._create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
else:
s = arg.sharding
xmap_sharding = pxla._create_mesh_pspec_sharding(
mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping))
# This check is cached because comparing OpSharding is expensive during
# dispatch and if the shardings are the same, then there is no need to
# compare twice.
_check_sharding(s, xmap_sharding, arg.ndim, arr_flavor)
# TODO: We should relax this at least for "constructor primitives"

View File

@ -989,9 +989,13 @@ class NamedNNTest(XMapTestCase):
atol=1e-4, rtol=2e-2)
@jax_config.jax_array(False)
class XMapGDATest(XMapTestCase):
def setUp(self):
super().setUp()
if config.jax_array:
self.skipTest('GDA and Array cannot be enabled together.')
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_basic(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@ -1123,6 +1127,27 @@ class XMapGDATest(XMapTestCase):
'specified in xmap. The partitioning must match.')):
f(gda_obj)
def test_gda_from_pjit_with_xmap_sharding_mismatch(self):
global_mesh = jtu.create_global_mesh((8, 1), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, lambda idx: input_data[idx])
with jax_config.parallel_functions_output_gda(True):
with global_mesh:
out = pjit(lambda x: x, in_axis_resources=P('x', 'y'),
out_axis_resources=P('x', 'y'))(gda_obj)
xmap_out = maps.xmap(
lambda x: x,
in_axes=({0: "a", 1: "b"}),
out_axes=({0: "a", 1: "b"}),
axis_resources={"a": "x", "b": "y"})(out) # doesn't crash
self.assertArraysEqual(xmap_out, input_data)
class XMapArrayTest(XMapTestCase):