Allow `pjit.AUTO` to be used with `jax.jit`. This introduces an API change: a mesh now has to be provided via `pjit.AUTO(mesh)`.
`with mesh:` is no longer required with pjit to use the auto SPMD pass of GSPMD.

PiperOrigin-RevId: 533801596
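Illustrative usage (not part of the commit): a minimal sketch of the new API, assuming a host with 8 visible devices; the mesh shape, axis names, and the toy function are arbitrary choices for the example. Per the commit message, `AUTO` now wraps the mesh itself, so neither `with mesh:` nor `pjit` is required:

import numpy as np
import jax
from jax.experimental import pjit  # pjit.AUTO, as named in the commit message

# Assumption for this sketch: 8 devices arranged as a 4x2 mesh.
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))

# Old style: f = pjit(fn, in_shardings=AUTO, out_shardings=AUTO) under `with mesh:`.
# New style: pass the mesh to AUTO directly; jax.jit accepts it as well.
f = jax.jit(lambda x: x * 2,
            in_shardings=pjit.AUTO(mesh),
            out_shardings=pjit.AUTO(mesh))

# Lower/compile against an abstract input and inspect what the auto SPMD
# partitioner chose (auto sharding may be unsupported on some backends,
# as the tests below note).
abstract_x = jax.ShapeDtypeStruct((8, 4), np.float32)
compiled = f.lower(abstract_x).compile()
print(compiled.input_shardings, compiled.output_shardings)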
parent e0b5003880
commit b71829f882
@@ -62,7 +62,7 @@ from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
-    AUTOAxisResource, UnspecifiedValue, UNSPECIFIED,
+    AUTO, UnspecifiedValue, UNSPECIFIED,
     get_array_mapping as _get_array_mapping, is_auto, is_unspecified
 )
 from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
@@ -1693,7 +1693,7 @@ TilingMethod = Union[TileVectorize, TileManual]
 
 def check_if_any_auto(
     shardings: Iterable[Union[sharding_impls.XLACompatibleSharding,
-                              AUTOAxisResource, UnspecifiedValue]]) -> bool:
+                              AUTO, UnspecifiedValue]]) -> bool:
   for s in shardings:
     if is_auto(s):
       return True
@@ -1755,8 +1755,7 @@ class DeviceAssignmentMismatchError(Exception):
 
 
 ShardingInfo = Tuple[
-    Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue,
-          AUTOAxisResource],
+    Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue, AUTO],
     MismatchType, Optional[Any]]  # Any is dispatch.SourceInfo to avoid circular imports
 
 
@@ -1775,13 +1774,14 @@ def _get_and_check_device_assignment(
     devices = tuple(devices)
 
   for i, s_type, source_info in shardings:
-    if is_auto(i) or is_unspecified(i):
+    if is_unspecified(i):
       continue
-    # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
-    # skipped.
+
     if first_sharding_info is None:
-      first_sharding_info = (i._device_assignment, s_type, source_info)  # type: ignore
-    arr_device_assignment = i._device_assignment  # type: ignore
+      first_sharding_info = (
+          (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i)  # type: ignore
+          else (i._device_assignment, s_type, source_info))  # type: ignore
+    arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment  # type: ignore
     if not devices:
       if first_sharding_info[0] != arr_device_assignment:
         raise DeviceAssignmentMismatchError([
@@ -1815,7 +1815,7 @@ def cache_wrap(fn):
 
 @cache_wrap
 def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
-                            keep_unused, donated_invars):
+                            keep_unused, donated_invars, auto_spmd_lowering):
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
 
   if isinstance(fun_or_jaxpr, lu.WrappedFun):
@@ -1830,7 +1830,7 @@ def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
     global_out_avals = fun_or_jaxpr.out_avals
     consts = fun_or_jaxpr.consts
 
-  if (keep_unused or
+  if (keep_unused or auto_spmd_lowering or
       any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
           for a in global_in_avals)):
     kept_var_idx = set(range(len(global_in_avals)))
@@ -2006,10 +2006,14 @@ def lower_sharding_computation(
     the singleton UNSPECIFIED to all out_avals.
   """
   # 1. Trace to jaxpr and preprocess/verify it
+  auto_spmd_lowering = (
+      check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
+      check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings])))  # type: ignore
+
   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
    kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
        fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
-       donated_invars)
+       donated_invars, auto_spmd_lowering)
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
 
@@ -2091,14 +2095,13 @@ def lower_sharding_computation(
       module,
       False,
       donated_invars,
-      mesh=None,
       global_in_avals=global_in_avals,
       global_out_avals=global_out_avals,
       in_shardings=in_shardings,
       out_shardings=out_shardings,
       spmd_lowering=True,
       tuple_args=tuple_args,
-      auto_spmd_lowering=False,
+      auto_spmd_lowering=auto_spmd_lowering,
       unordered_effects=unordered_effects,
       ordered_effects=ordered_effects,
       host_callbacks=host_callbacks,
@@ -2112,7 +2115,7 @@ def lower_sharding_computation(
 
 
 def _to_logical_sharding(
-    aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
+    aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTO]
 ) -> Optional[sharding_impls.XLACompatibleSharding]:
   if is_unspecified(sharding) or is_auto(sharding):
     return None
@@ -2131,8 +2134,8 @@ def lower_mesh_computation(
     api_name: str,
     fun_name: str,
     mesh: Mesh,
-    in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource]],
-    out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTOAxisResource,
+    in_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO]],
+    out_shardings: Sequence[Union[sharding_impls.NamedSharding, AUTO,
                                   UnspecifiedValue]],
     donated_invars: Sequence[bool],
     spmd_lowering: bool,
@@ -2143,11 +2146,6 @@ def lower_mesh_computation(
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
 
-  auto_spmd_lowering = check_if_any_auto((*in_shardings, *out_shardings))
-
-  if auto_spmd_lowering and not spmd_lowering:
-    raise ValueError('Enable spmd_lowering to use auto spmd lowering.')
-
   global_axis_sizes = mesh.shape
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -2171,7 +2169,6 @@ def lower_mesh_computation(
     else:
       raise NotImplementedError(f"Unrecognized tiling method: {tiling_method}")
     assert not callable(out_shardings)
-    assert not auto_spmd_lowering
     assert isinstance(fun_or_jaxpr, lu.WrappedFun)
     # This is the xmap path where there is no `AUTO` or `UNSPECIFIED`, which
     # is why `.spec` can be accessed.
@@ -2181,7 +2178,6 @@ def lower_mesh_computation(
     in_jaxpr_avals = global_in_avals
   else:
     assert isinstance(tiling_method, TileVectorize)
-    assert not auto_spmd_lowering
     # In non-spmd lowering path, there is no `AUTO` or `UNSPECIFIED`, which is
     # why `.spec` can be accessed.
     in_tiled_avals = [tile_aval_nd(global_axis_sizes, get_array_mapping(i.spec), aval)  # type: ignore
@@ -2274,14 +2270,13 @@ def lower_mesh_computation(
       lowering_result.module,
       False,
       donated_invars,
-      mesh=mesh,
       global_in_avals=global_in_avals,
       global_out_avals=global_out_avals,
       in_shardings=in_shardings,
       out_shardings=out_shardings,
       spmd_lowering=spmd_lowering,
       tuple_args=tuple_args,
-      auto_spmd_lowering=auto_spmd_lowering,
+      auto_spmd_lowering=False,
       unordered_effects=unordered_effects,
       ordered_effects=ordered_effects,
       host_callbacks=lowering_result.host_callbacks,
@@ -2501,12 +2496,6 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
   device_assignment = da.device_assignment if isinstance(
       da, _DeviceAssignment) else da
 
-  dev: np.ndarray
-  if auto_spmd_lowering:
-    assert mesh is not None and spmd_lowering
-    dev = mesh.devices
-    num_replicas, num_partitions = 1, mesh.size
-  else:
   # TODO(phawkins): One would normally just write:
   # dev = np.array(device_assignment)
   # The formulation below is substantially faster if there are many devices.
@@ -2610,13 +2599,10 @@ class UnloadedMeshExecutable:
   @staticmethod
   def from_hlo(name: str,
                hlo: ir.Module,
-               # TODO(yashkatariya): Remove `mesh` from here once AUTO can work
-               # without mesh.
-               mesh: Optional[Mesh],
                global_in_avals: Sequence[ShapedArray],
                global_out_avals: Sequence[ShapedArray],
-               in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource]],
-               out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTOAxisResource,
+               in_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO]],
+               out_shardings: Sequence[Union[sharding_impls.XLACompatibleSharding, AUTO,
                                              UnspecifiedValue]],
                spmd_lowering: bool,
                tuple_args: bool,
@@ -2641,6 +2627,14 @@ class UnloadedMeshExecutable:
         device_assignment, _DeviceAssignment) else tuple(device_assignment)
     del device_assignment
     allow_prop_to_outputs = tuple(is_unspecified(o) for o in out_shardings)
+
+    mesh = None
+    if auto_spmd_lowering:
+      for i in it.chain.from_iterable([in_shardings, out_shardings]):
+        if is_auto(i):
+          mesh = i.mesh  # type: ignore
+          break
+
     xla_executable, compile_options = _cached_compilation(
         hlo, name, mesh, spmd_lowering,
         tuple_args, auto_spmd_lowering, allow_prop_to_outputs,
@@ -2661,7 +2655,7 @@ class UnloadedMeshExecutable:
       assert mesh is not None
       in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
           xla_executable, mesh)
-      in_shardings = [x if is_auto(i) else i
+      in_shardings = [x if is_auto(i) else getattr(i, '_original_sharding', i)  # type: ignore
                       for x, i in safe_zip(in_shardings_xla, in_shardings)]
       out_shardings_tuple = [
           (x, True) if is_auto(o) else (o, False)
@@ -54,7 +54,7 @@ from jax._src.lib import xla_client as xc
 from jax._src.sharding_impls import (
     NamedSharding, XLACompatibleSharding, GSPMDSharding,
     XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
-    AUTOAxisResource, UNSPECIFIED, UnspecifiedValue,
+    AUTO, UNSPECIFIED, UnspecifiedValue,
     ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
     is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
 from jax._src.traceback_util import api_boundary
@@ -72,10 +72,10 @@ zip, unsafe_zip = safe_zip, zip
 
 traceback_util.register_exclusion(__file__)
 
-PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource]
-PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
-MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
-MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
+PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTO]
+PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTO]
+MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTO]
+MeshShardingMinusUnspecified = Union[NamedSharding, AUTO]
 
 logger = logging.getLogger(__name__)
 
@@ -342,6 +342,7 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
      donate_argnums) = infer_params_fn(*args, **kwargs)
     resource_env = params['resource_env']
     mesh = None if resource_env is None else resource_env.physical_mesh
+    try:
       in_shardings = _resolve_in_shardings(
           args_flat, params['in_shardings'], params['out_shardings'], mesh)
       lowering = _pjit_lower(
@@ -349,6 +350,14 @@ def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
           params['resource_env'], params['donated_invars'], params['name'],
           params['keep_unused'], params['inline'], always_lower=True,
           lowering_platform=_experimental_lowering_platform)
+    except pxla.DeviceAssignmentMismatchError as e:
+      fails, = e.args
+      api_name = 'jit' if params['resource_env'] is None else 'pjit'
+      arg_names = _get_arg_names(fun, in_tree, args_flat)
+      fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
+      msg = _device_assignment_mismatch_error(
+          fun_name, fails, args_flat, api_name, arg_names)
+      raise ValueError(msg) from None
 
     if kwargs:
       args_kwargs_in_tree = in_tree
@@ -1210,29 +1219,9 @@ def _pjit_lower_cached(
     mesh = None
     api_name = 'jit'
 
-  # Convert to `NamedSharding` when `jax_array` is not enabled. This is
-  # because GDA/SDA/DA are dependent on mesh for generating outputs.
-  # NamedSharding is required for host-local inputs too.
-  any_auto = pxla.check_if_any_auto(it.chain(in_shardings, out_shardings))
-  if any_auto:
-    in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast(  # type:ignore[no-redef]
-        Tuple[MeshShardingMinusUnspecified, ...], tuple(
-            NamedSharding._from_parsed_pspec(
-                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])  # type: ignore
-            if isinstance(i, GSPMDSharding) else i
-            for i in in_shardings
-    ))
-    out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
-        Tuple[MeshSharding, ...], tuple(
-            NamedSharding._from_parsed_pspec(
-                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])  # type: ignore
-            if isinstance(o, GSPMDSharding) else o
-            for o in out_shardings
-    ))
-
   # For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
   # because `xmap` only supports SPMDAxisContext right now.
-  if any_auto or dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
+  if dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
     return pxla.lower_mesh_computation(
         jaxpr, api_name, name, mesh,
         in_shardings, out_shardings, donated_invars,
@@ -1929,10 +1918,11 @@ def _fast_path_get_device_assignment(
     shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
   da = None
   for i in shardings:
-    if is_auto(i) or is_unspecified(i):
+    if is_unspecified(i):
      continue
-    da = i._device_assignment  # type: ignore
-    break
+    if is_auto(i):
+      return i.mesh._flat_devices_tuple  # type: ignore
+    return i._device_assignment  # type: ignore
   return da
 
 
@@ -719,12 +719,14 @@ class GSPMDSharding(XLACompatibleSharding):
     return cls(tuple(device_assignment), proto)
 
 
-class AUTOAxisResource:
-  pass
-AUTO = AUTOAxisResource()
+class AUTO:
+
+  def __init__(self, mesh: mesh_lib.Mesh):
+    self.mesh = mesh
+
 
 def is_auto(x):
-  return isinstance(x, AUTOAxisResource)
+  return isinstance(x, AUTO)
 
 
 class UnspecifiedValue:
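A small usage sketch of the reworked marker class above (the mesh construction and variable names below are assumptions for illustration, not code from the commit):

import numpy as np
import jax
from jax.experimental import pjit

# AUTO is no longer a module-level singleton; it is now a marker object that
# carries the mesh it should be resolved against.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x',))
marker = pjit.AUTO(mesh)
assert marker.mesh is mesh  # the wrapped mesh is stored on the instance

# Internal helpers such as is_auto()/get_array_mapping() branch on this marker,
# and device assignments are read from marker.mesh rather than from an ambient
# `with mesh:` context.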
@@ -757,8 +759,7 @@ mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, the
 mesh devices ndarray would have to be transposed before flattening and assignment.
 """
 ArrayMapping = OrderedDictType[MeshAxisName, int]
-ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
-                                        UnspecifiedValue]
+ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
 
 def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
   if not array_mapping:
@@ -779,11 +780,11 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
   return PartitionSpec(*partitions)
 
 def get_array_mapping(
-    axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue]
+    axis_resources: Union[ParsedPartitionSpec, AUTO, UnspecifiedValue]
 ) -> ArrayMappingOrAutoOrUnspecified:
   # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
   # Don't use `is_auto` here to satisfy pytype and mypy.
-  if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)):
+  if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
     return axis_resources
   return OrderedDict((axis, i)
                      for i, axes in enumerate(axis_resources)
@@ -1335,8 +1335,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
     input_data = np.arange(
         math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
 
-    with global_mesh:
-      f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
+    f = jax.jit(lambda x: x, in_shardings=AUTO(global_mesh),
+                out_shardings=AUTO(global_mesh))
 
     inp = core.ShapedArray(input_data.shape, input_data.dtype)
     compiled = f.lower(inp).compile()
@@ -1355,7 +1355,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
         math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
 
     with global_mesh:
-      f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
+      f = pjit(lambda x: x, in_shardings=AUTO(global_mesh),
+               out_shardings=AUTO(global_mesh))
       inp = core.ShapedArray(input_data.shape, input_data.dtype)
       compiled = f.lower(inp).compile()
 
@@ -1379,7 +1380,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
         math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
 
     with global_mesh:
-      f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO, out_shardings=AUTO)
+      f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO(global_mesh),
+               out_shardings=AUTO(global_mesh))
       inp = core.ShapedArray(input_data.shape, input_data.dtype)
       compiled = f.lower(inp, inp, inp).compile()
       self.assertLen(compiled.output_shardings, 3)
@@ -1390,33 +1392,42 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
     ('2d_array', (4, 2), ('x', 'y'), P('y', 'x')),
     ('1d_array', (8,), ('x'), P('x')),
   )
-  def test_pjit_arr_partial_auto_sharding_array(
+  def test_jit_arr_partial_auto_sharding_array(
       self, mesh_shape, mesh_axis_names, pspec):
     if xla_bridge.get_backend().runtime_type == 'stream_executor':
       raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
-    global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
+    mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
    global_input_shape = (8, 4)
     input_data = np.arange(
         math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
 
-    in_resource = NamedSharding(global_mesh, pspec)
-
-    with global_mesh:
-      f = pjit(
+    inp_s = NamedSharding(mesh, pspec)
+    f = jax.jit(
         lambda x, y: (x, y),
-          in_shardings=(in_resource, AUTO),
-          out_shardings=AUTO,
-      )
-
+        in_shardings=(inp_s, AUTO(mesh)),
+        out_shardings=AUTO(mesh))
     inp = core.ShapedArray(input_data.shape, input_data.dtype)
     compiled = f.lower(inp, inp).compile()
-    inputs = [create_array(global_input_shape, global_mesh, ip, input_data)[0]
+    inputs = [create_array(global_input_shape, mesh, ip, input_data)[0]
               for ip in compiled.input_shardings[0]]
+    self.assertEqual(compiled.input_shardings[0][0], inp_s)
     out1, out2 = compiled(*inputs)
     for o in [out1, out2]:
       self.assertIsInstance(o, array.ArrayImpl)
       self.assertArraysEqual(o._value, input_data)
+
+  def test_jit_different_mesh_in_auto(self):
+    mesh1 = jtu.create_global_mesh((4,), ('x',))
+    dev = jax.devices()
+    mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x')
+    f = jax.jit(lambda x, y: (x, y),
+                in_shardings=(NamedSharding(mesh2, P('x')), AUTO(mesh1)))
+    inp = core.ShapedArray((8, 2), np.float32)
+    with self.assertRaisesRegex(
+        ValueError,
+        "Received incompatible devices for jitted computation"):
+      f.lower(inp, inp).compile()
 
   @unittest.skip('The error is not raised yet. Enable this back once we raise '
                  'the error in pjit again.')
   def test_pjit_array_error(self):
@@ -1429,7 +1440,8 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
         math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
 
     with global_mesh:
-      f = pjit(lambda x: x, in_shardings=AUTO, out_shardings=AUTO)
+      f = pjit(lambda x: x, in_shardings=AUTO(global_mesh),
+               out_shardings=AUTO(global_mesh))
 
       inp = core.ShapedArray(input_data.shape, input_data.dtype)
       compiled = f.lower(inp).compile()