mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make the gda and xmap sharding check work generally by checking the OpSharding protos.
PiperOrigin-RevId: 475560097
This commit is contained in:
parent
24bc153e49
commit
fc2902c6ac
@ -20,7 +20,7 @@ from collections import OrderedDict, abc
|
||||
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
|
||||
NamedTuple, Union, Sequence)
|
||||
from warnings import warn
|
||||
from functools import wraps, partial, partialmethod
|
||||
from functools import wraps, partial, partialmethod, lru_cache
|
||||
from enum import Enum
|
||||
|
||||
from jax import numpy as jnp
|
||||
@ -1811,6 +1811,17 @@ def _check_out_avals_vs_out_axes(out_avals: Sequence[core.AbstractValue],
|
||||
def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
||||
global_axis_sizes, in_axes_flat,
|
||||
in_positional_semantics, args_flat):
|
||||
@lru_cache()
|
||||
def _check_sharding(in_sharding, xmap_sharding, ndim, arr_flavor):
|
||||
if not pxla.are_op_shardings_equal(
|
||||
in_sharding._to_xla_op_sharding(ndim),
|
||||
xmap_sharding._to_xla_op_sharding(ndim)):
|
||||
raise ValueError(
|
||||
f"Got an input {arr_flavor} to xmap with different partitioning than "
|
||||
"specified in xmap. The partitioning must match. "
|
||||
f"Got {arr_flavor} spec: {in_sharding.spec} and "
|
||||
f"xmap spec: {xmap_sharding.spec}")
|
||||
|
||||
mesh_in_axes = EvaluationPlan.from_axis_resources(
|
||||
axis_resources, resource_env, global_axis_sizes,
|
||||
in_positional_semantics).to_mesh_axes(in_axes_flat)
|
||||
@ -1825,15 +1836,16 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
||||
f"Got xmap mesh: {resource_env.physical_mesh},\n"
|
||||
f"{arr_flavor} mesh: {mesh}")
|
||||
|
||||
array_mapping = pxla._get_array_mapping(
|
||||
arg.mesh_axes if arr_flavor == 'GDA' else arg.sharding.spec)
|
||||
if array_mapping != xmap_array_mapping:
|
||||
raise ValueError(
|
||||
f"Got an input {arr_flavor} to xmap with different partitioning than "
|
||||
"specified in xmap. The partitioning must match. "
|
||||
f"Got {arr_flavor} spec: {pxla.array_mapping_to_axis_resources(array_mapping)} and "
|
||||
f"xmap spec: {pxla.array_mapping_to_axis_resources(xmap_array_mapping)} "
|
||||
f"for {arr_flavor}: {arg}")
|
||||
if arr_flavor == 'GDA':
|
||||
s = pxla._create_mesh_pspec_sharding(arg.mesh, arg.mesh_axes)
|
||||
else:
|
||||
s = arg.sharding
|
||||
xmap_sharding = pxla._create_mesh_pspec_sharding(
|
||||
mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping))
|
||||
# This check is cached because comparing OpSharding is expensive during
|
||||
# dispatch and if the shardings are the same, then there is no need to
|
||||
# compare twice.
|
||||
_check_sharding(s, xmap_sharding, arg.ndim, arr_flavor)
|
||||
|
||||
|
||||
# TODO: We should relax this at least for "constructor primitives"
|
||||
|
@ -989,9 +989,13 @@ class NamedNNTest(XMapTestCase):
|
||||
atol=1e-4, rtol=2e-2)
|
||||
|
||||
|
||||
@jax_config.jax_array(False)
|
||||
class XMapGDATest(XMapTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if config.jax_array:
|
||||
self.skipTest('GDA and Array cannot be enabled together.')
|
||||
|
||||
@jtu.with_mesh([('x', 4), ('y', 2)])
|
||||
def test_basic(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
@ -1123,6 +1127,27 @@ class XMapGDATest(XMapTestCase):
|
||||
'specified in xmap. The partitioning must match.')):
|
||||
f(gda_obj)
|
||||
|
||||
def test_gda_from_pjit_with_xmap_sharding_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((8, 1), ('x', 'y'))
|
||||
global_input_shape = (8, 2)
|
||||
mesh_axes = P('x', 'y')
|
||||
input_data = np.arange(
|
||||
prod(global_input_shape)).reshape(global_input_shape)
|
||||
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
|
||||
global_input_shape, global_mesh, mesh_axes, lambda idx: input_data[idx])
|
||||
with jax_config.parallel_functions_output_gda(True):
|
||||
with global_mesh:
|
||||
out = pjit(lambda x: x, in_axis_resources=P('x', 'y'),
|
||||
out_axis_resources=P('x', 'y'))(gda_obj)
|
||||
|
||||
xmap_out = maps.xmap(
|
||||
lambda x: x,
|
||||
in_axes=({0: "a", 1: "b"}),
|
||||
out_axes=({0: "a", 1: "b"}),
|
||||
axis_resources={"a": "x", "b": "y"})(out) # doesn't crash
|
||||
self.assertArraysEqual(xmap_out, input_data)
|
||||
|
||||
|
||||
|
||||
class XMapArrayTest(XMapTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user