Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00.
If input layouts are specified via `in_shardings` to `jit`, and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.

Skipping the reshard leads to incorrect outputs on GPU and a crash on TPU.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
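A minimal sketch of the user-facing behavior this commit fixes, modeled on the `test_in_layouts_jit_jnp_input` test added below; the `jax._src.layout` import path is the internal module used throughout this diff, and the layout/sharding values are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax._src.layout import DeviceLocalLayout as DLL, Layout

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
major_last_layout = DLL(major_to_minor=(1, 0))

# jit is told, via in_shardings, which input layout to expect.
f = jax.jit(lambda x: x + 1,
            in_shardings=Layout(major_last_layout, sharding))

# A plain numpy (or jnp) argument is uncommitted, so it must be resharded
# to the requested layout. Before this fix it was passed through as-is,
# producing wrong outputs on GPU and crashing on TPU.
np_inp = np.arange(8 * 128).reshape(8, 128)
out = f(np_inp)  # expected: np_inp + 1
```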
parent 292161ab4d
commit 6e1c23610d
```diff
@@ -1827,7 +1827,7 @@ def _cpp_pmap(
   cpp_mapped_f = pmap_lib.pmap(
       fun, cache_miss, static_broadcasted_tuple,
-      lambda x, s: pxla.shard_args([s], [x])[0],
+      lambda x, s: pxla.shard_args([s], [None], [x])[0],
       pytree_registry=tree_util.default_registry)
   _pmap_cache_clears.add(cpp_mapped_f)
```
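Every `pxla.shard_args` call site in this diff gains a `layouts` argument in the second position; callers with no layout constraint pass one `None` per argument. A minimal sketch of the new calling convention (the single-device sharding is chosen only for illustration):

```python
import jax
import numpy as np
from jax._src.interpreters import pxla

x = np.arange(16.0)
s = jax.sharding.SingleDeviceSharding(jax.devices()[0])

# New signature: shard_args(shardings, layouts, args, canonicalize=True).
# A None layout entry means "use the default layout for this aval/sharding".
out, = pxla.shard_args([s], [None], [x])
```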
```diff
@@ -1086,9 +1086,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
     # Look up all buffers that contain the correct slice of the logical array.
     candidates_list = candidates[hashed_index(idx)]
     if not candidates_list:
       # This array isn't sharded correctly. Reshard it via host roundtrip.
       # TODO(skye): more efficient reshard?
-      return pxla.shard_args([sharding], [x._value], canonicalize=False)[0]
+      return pxla.shard_args([sharding], [None], [x._value],
+                             canonicalize=False)[0]
     # Try to find a candidate buffer already on the correct device,
     # otherwise copy one of them.
     for buf in candidates_list:
@@ -1097,7 +1096,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
         break
     else:
       bufs.append(buf)
 
   return pxla.batched_device_put(x.aval, sharding, bufs, devices)
```
```diff
@@ -1107,24 +1105,30 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
   dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
   return dst_indices, tuple(src_indices) == tuple(dst_indices)
 
 
+def _layout_eq(x, dst_layout, sharding):
+  if pxla.is_default_layout(dst_layout, sharding, x.aval):
+    return True
+  return x.layout.device_local_layout == dst_layout
+
+
-def _array_shard_arg(xs, shardings):
+def _array_shard_arg(xs, shardings, layouts):
   results = []
   batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
-  for i, (x, sharding) in enumerate(safe_zip(xs, shardings)):
-    x._check_if_deleted()
-
-    indices, same_indices = _sharding_indices_and_eq(
-        x.sharding, x.shape, sharding)
+  for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
+    x._check_if_deleted()
+    indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
+    same_layout = _layout_eq(x, layout, sharding)
 
     if not x.is_fully_addressable:
-      if same_indices:
+      if same_indices and same_layout:
         results.append(x)
       else:
         raise NotImplementedError(
             "Cannot reshard an input that is not fully addressable")
     else:
       devices = sharding._addressable_device_assignment
-      if same_indices:
+      if same_indices and same_layout:
         # Add a placeholder result that will be filled in later.
         results.append(None)
         # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`.
@@ -1133,6 +1137,8 @@ def _array_shard_arg(xs, shardings):
         batch_shardings.append(sharding)
         batch_indices.append(i)
       # Resharding starts here:
+      elif not same_layout:
+        results.append(api.device_put(x, Layout(layout, sharding)))
       elif dispatch.is_single_device_sharding(x.sharding):
         results.append(shard_device_array(x, devices, indices, sharding))
       else:
@@ -1145,8 +1151,6 @@ def _array_shard_arg(xs, shardings):
     assert results[i] is None
     results[i] = copy_out
   return results
 
 
 pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg
```
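Restated outside the diff, the per-argument dispatch in `_array_shard_arg` now keys on layout equality as well as index equality. A standalone sketch of the decision table (an illustration, not JAX's code):

```python
def choose_path(fully_addressable, same_indices, same_layout, single_device):
  """Illustrates the branch order in the new _array_shard_arg."""
  if not fully_addressable:
    if same_indices and same_layout:
      return "reuse the input array as-is"
    return "NotImplementedError: cannot reshard a non-addressable input"
  if same_indices and same_layout:
    return "batched fast-path copy (placeholder filled in later)"
  if not same_layout:
    return "device_put(x, Layout(layout, sharding))  # the new reshard"
  if single_device:
    return "shard_device_array"
  return "slow-path reshard"
```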
```diff
@@ -1178,8 +1182,8 @@ pxla.local_result_handlers[core.ConcreteArray] = _array_local_result_handler
 
 # Token handlers
 
-def _token_shard_arg(xs, shardings):
-  return _array_shard_arg([x._buf for x in xs], shardings)
+def _token_shard_arg(xs, shardings, layouts):
+  return _array_shard_arg([x._buf for x in xs], shardings, layouts)
 pxla.shard_arg_handlers[core.Token] = _token_shard_arg
 
```
```diff
@@ -134,7 +134,7 @@ class RuntimeTokenSet(threading.local):
     # We only use replicated sharding for the first time when the token for the
     # order effect hasn't been created.
     s = jax.sharding.GSPMDSharding.get_replicated(devices)
-    sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
+    sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0])
     self.current_tokens[eff] = sharded_tok
     return sharded_tok
@@ -515,7 +515,10 @@ def _batched_device_put_impl(
   if shard_arg_xs:
     # Batch shard_arg calls. Helps improve efficiency for backends that support
     # efficient batch transfer.
-    shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
+    # device_put handles `Layout` via a different path, so just pass `None` as
+    # the layout here.
+    shard_arg_results = pxla.shard_args(
+        shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs)
     for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
       assert isinstance(ys[i], _DeferredShardArg)
       ys[i] = ys[i].result_handler(shard_arg_result)
```
```diff
@@ -104,11 +104,12 @@ class EArray(basearray.Array):
 
   # TODO(mattjj): _set_array_base_attributes
 
-def _earray_shard_arg_handler(xs, shardings):
+def _earray_shard_arg_handler(xs, shardings, layouts):
   arrs = [x._data for x in xs]
   phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
                     for x, sharding in zip(xs, shardings)]
-  return pxla.shard_args(phys_shardings, arrs)
+  # TODO(yashkatariya): `layouts` should be converted to physical layouts.
+  return pxla.shard_args(phys_shardings, layouts, arrs)
 pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler
 
 api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
```
```diff
@@ -1000,7 +1000,7 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
     return "auto"
   if aval is core.abstract_token:
     return "default"
-  return layout._to_xla_layout(aval.dtype)  # type: ignore
+  return str(layout._to_xla_layout(aval.dtype))  # type: ignore
 
 
 def _get_mem_kind(s: JSharding | None) -> str | None:
```
```diff
@@ -32,6 +32,7 @@ import numpy as np
 
 import jax
 
+from jax._src import api
 from jax._src import api_util
 from jax._src import compiler
 from jax._src import config
@@ -60,6 +61,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
```
```diff
@@ -106,39 +108,67 @@ ShardingSpec = sharding_specs.ShardingSpec
 def identity(x): return x
 
 @profiler.annotate_function
-def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]:
+def shard_args(shardings: Sequence[JSharding], layouts, args,
+               canonicalize=True) -> Sequence[xc.ArrayImpl]:
   # Fast path for one argument.
   if len(args) == 1:
     arg = args[0]
     if canonicalize:
       arg = xla.canonicalize_dtype(arg)
-    return shard_arg_handlers[type(arg)]([arg], shardings)
+    return shard_arg_handlers[type(arg)]([arg], shardings, layouts)
 
-  # type(arg) -> (indices, args, shardings)
-  batches = collections.defaultdict(lambda: ([], [], []))  # type: ignore
-  for i, (arg, sharding) in enumerate(safe_zip(args, shardings)):
+  # type(arg) -> (list[indices], list[args], list[shardings])
+  batches = collections.defaultdict(lambda: ([], [], [], []))  # type: ignore
+  for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)):
     if canonicalize:
       arg = xla.canonicalize_dtype(arg)
     batch = batches[type(arg)]
     batch[0].append(i)
     batch[1].append(arg)
     batch[2].append(sharding)
+    batch[3].append(layout)
 
   # Call `shard_arg_handlers` per batch and build a flat list of arrays returned
   # from each call in the same order as `args`. Since `batches` is grouped by
   # types, we cannot simply flatten the results and we have to use the original
   # indices to put each array back to its original position.
   results: list[jax.Array | None] = [None] * len(args)
-  for t, (indices, a, s) in batches.items():
-    outs = shard_arg_handlers[t](a, s)
+  for t, (indices, a, s, l) in batches.items():
+    outs = shard_arg_handlers[t](a, s, l)
     for i, out in safe_zip(indices, outs):
       results[i] = out
 
   assert all(result is not None for result in results)
   return results
 
 
-shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {}
+shard_arg_handlers: dict[
+    Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]]
+] = {}
+
+
+def is_default_layout(curr_layout, sharding, aval):
+  if curr_layout is None or sharding is None:
+    return True
+  if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
+      dtypes.issubdtype(aval.dtype, dtypes.extended)):
+    return True
+  if isinstance(curr_layout, AutoLayout):
+    return False
+  d = sharding._device_assignment[0]
+  shard_shape = sharding.shard_shape(aval.shape)
+  try:
+    # TODO(yashkatariya): Replace this with normal `==` check once CPU supports
+    # int4.
+    return is_user_xla_layout_equal(
+        curr_layout,
+        DeviceLocalLayout.from_pjrt_layout(
+            d.client.get_default_layout(aval.dtype, shard_shape, d)))
+  except xe.XlaRuntimeError as e:
+    msg, *_ = e.args
+    if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"):
+      return True
+    else:
+      raise
 
 
 @lru_cache(maxsize=1024)
```
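With `shard_args` changed, every entry registered in `shard_arg_handlers` must now accept three parallel sequences. A hedged sketch of registering a handler for a custom container under the new convention; the `MyWrapper` type is hypothetical, and the delegation pattern mirrors `_shard_darray` in the next hunk:

```python
import numpy as np
import jax
from jax._src.interpreters import pxla


class MyWrapper:
  """Hypothetical container around a plain numpy payload."""
  def __init__(self, data):
    self.data = data


def _my_wrapper_shard_arg(xs, shardings, layouts):
  # Handlers now receive a layouts sequence aligned with xs and shardings;
  # a handler that wraps another type simply threads the layouts through.
  return pxla.shard_args(shardings, layouts, [x.data for x in xs])


pxla.shard_arg_handlers[MyWrapper] = _my_wrapper_shard_arg
```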
```diff
@@ -146,34 +176,37 @@ def _get_replicated_slices(num_addressable_devices: int):
   return ((slice(None),),) * num_addressable_devices
 
 
-def _masked_array_error(xs, shardings):
+def _masked_array_error(xs, shardings, layouts):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
 
-def _shard_array(xs, shardings):
+def _shard_np_array(xs, shardings, layouts):
   results = []
-  for x, sharding in safe_zip(xs, shardings):
+  for x, sharding, layout in safe_zip(xs, shardings, layouts):
     devices = sharding._addressable_device_assignment
     if x.dtype == dtypes.float0:
       x = np.zeros(x.shape, dtype=np.dtype(bool))
     aval = api_util.shaped_abstractify(x)
-    if sharding.is_fully_replicated:
-      shards = [x] * len(devices)
+    if not is_default_layout(layout, sharding, aval):
+      results.append(api.device_put(x, Layout(layout, sharding)))
     else:
-      indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
-      shards = [x[i] for i in indices]
-    results.append(batched_device_put(aval, sharding, shards, devices))
+      if sharding.is_fully_replicated:
+        shards = [x] * len(devices)
+      else:
+        indices = tuple(sharding.addressable_devices_indices_map(x.shape).values())
+        shards = [x[i] for i in indices]
+      results.append(batched_device_put(aval, sharding, shards, devices))
   return results
 for _t in array_types:
-  shard_arg_handlers[_t] = _shard_array
+  shard_arg_handlers[_t] = _shard_np_array
 
-def _shard_darray(xs, shardings):
-  return shard_args(shardings, [x._data for x in xs])
+def _shard_darray(xs, shardings, layouts):
+  return shard_args(shardings, layouts, [x._data for x in xs])
 shard_arg_handlers[core.DArray] = _shard_darray
 
-def _shard_mutable_array(xs, shardings):
-  return shard_args(shardings, [x._buf for x in xs])
+def _shard_mutable_array(xs, shardings, layouts):
+  return shard_args(shardings, layouts, [x._buf for x in xs])
 shard_arg_handlers[core.MutableArray] = _shard_mutable_array
 
 def batched_device_put(aval: core.ShapedArray,
```
```diff
@@ -931,6 +964,7 @@ class UnloadedPmapExecutable:
     handle_outs = local_avals_to_results_handler(self.local_output_avals,
                                                  self.output_shardings)
     handle_args = InputsHandler(self.input_shardings,
+                                [None] * len(self.input_shardings),
                                 self.compiled.local_devices(), input_indices)
     execute_fun = ExecuteReplicated(self.compiled, "parallel computation",
                                     self.backend, handle_args, handle_outs,
@@ -1109,12 +1143,15 @@ def _get_pmap_sharding(devices, specs):
 
 
 class InputsHandler:
-  __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
+  __slots__ = ("handler", "in_shardings", "in_layouts", "local_devices",
+               "input_indices")
 
-  def __init__(self, in_shardings, local_devices=None, input_indices=None):
-    self.handler = partial(shard_args, in_shardings)
-    self.local_devices = local_devices
+  def __init__(self, in_shardings, in_layouts, local_devices=None,
+               input_indices=None):
+    self.handler = partial(shard_args, in_shardings, in_layouts)
     self.in_shardings = in_shardings
+    self.in_layouts = in_layouts
+    self.local_devices = local_devices
     self.input_indices = input_indices
 
   def __call__(self, input_buffers):
@@ -1122,8 +1159,9 @@ class InputsHandler:
 
   def __str__(self):
     return ("InputsHandler(\n"
-            f"local_devices={self.local_devices},\n"
             f"in_shardings={self.in_shardings},\n"
+            f"in_layouts={self.in_layouts},\n"
+            f"local_devices={self.local_devices},\n"
            f"input_indices={self.input_indices})")
```
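`InputsHandler` now carries layouts alongside shardings and partially applies both onto `shard_args`. A small sketch of constructing it the way the pmap executable above does, with no layout constraints (the argument values are illustrative):

```python
import jax
import numpy as np
from jax._src.interpreters import pxla

args = [np.arange(4.0), np.ones((2, 2))]
in_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])] * len(args)

# pmap has no user-specified layouts, so it passes one None per input.
handler = pxla.InputsHandler(in_shardings, [None] * len(in_shardings))
sharded = handler(args)  # applies shard_args(in_shardings, in_layouts, args)
```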
```diff
@@ -1849,7 +1887,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
   if is_unspecified_or_auto(sharding):
     return None
   # TODO(yashkatariya): Figure out how layouts work with extended dtypes.
-  if dtypes.issubdtype(aval.dtype, dtypes.extended):
+  if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended):
     return None
   if not core.is_constant_shape(aval.shape):
     return None
@@ -2505,7 +2543,7 @@ def maybe_recover_user_shardings(
 
 def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
                              xl: DeviceLocalLayout) -> bool:
-  if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
+  if isinstance(ul, DeviceLocalLayout) and not ul._tiling:
     return ul.major_to_minor == xl.major_to_minor
   else:
     return ul == xl
@@ -2742,7 +2780,7 @@ class UnloadedMeshExecutable:
   pgle_profiler: profiler.PGLEProfiler | None
 
   def build_unsafe_call(self):
-    handle_args = InputsHandler(self.input_shardings)
+    handle_args = InputsHandler(self.input_shardings, self.in_layouts)
     handle_outs = global_avals_to_results_handler(
         self.output_avals, self.output_shardings, self.committed)
 
@@ -2882,9 +2920,7 @@ class MeshExecutableFastpathData(NamedTuple):
   out_avals: Sequence[ShapedArray]
   out_committed: Sequence[bool]
   kept_var_bitvec: Iterable[bool]
-  # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24
-  arg_handler_devices: Sequence[xc.Device]
-  arg_handler_indices: Sequence[tuple[Index | None, ...]]
+  in_device_local_layouts: Sequence[DeviceLocalLayout | None]
 
 
 def reflatten_outputs_for_dispatch(out_tree, out_flat):
@@ -2992,18 +3028,36 @@ class MeshExecutable(stages.XlaExecutable):
             else s
             for s, a in zip(self._in_shardings, self.in_avals)
         ]
+        in_dlls = get_layouts_for_fasthpath_data(
+            self._in_layouts, in_shardings, self.in_avals)
         fastpath_data = MeshExecutableFastpathData(
             self.xla_executable, out_tree_dispatch, in_shardings,
             self._out_shardings, out_avals, out_committed, kept_var_bitvec,
-            self.unsafe_call.in_handler.local_devices,
-            self.unsafe_call.in_handler.input_indices)
+            in_dlls)
       else:
         fastpath_data = None
       return outs, fastpath_data, False  # Do not remove cache entry
 
     return xc._xla.pjit(
         self.unsafe_call.name, None, aot_cache_miss, [], [], [],
-        tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0])
+        tree_util.dispatch_registry, cc_shard_arg)
+
+
+if xla_extension_version < 282:
+  def cc_shard_arg(x, sharding):
+    return shard_args([sharding], [None], [x])[0]
+else:
+  def cc_shard_arg(x, sharding, layout):  # type: ignore
+    return shard_args([sharding], [layout], [x])[0]
+
+
+def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals):
+  in_dlls = []
+  for l, s, a in zip(in_layouts, in_shardings, in_avals):
+    if is_default_layout(l, s, a):
+      in_dlls.append(None)
+    else:
+      in_dlls.append(l)
+  return in_dlls
 
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
```
```diff
@@ -688,7 +688,7 @@ def _maybe_put(x):
     aval = shaped_abstractify(x)
     s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0])
     result_handler = pxla.global_aval_to_result_handler(aval, s, False)
-    return result_handler(pxla.shard_args([s], [x]))
+    return result_handler(pxla.shard_args([s], [None], [x]))
   else:
     return x
```
```diff
@@ -69,7 +69,7 @@ class DeviceLocalLayout:
             self._tiling == other._tiling and
             self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
 
-  def _to_xla_layout(self, dtype) -> str:
+  def _to_xla_layout(self, dtype) -> xc.Layout:
     if self._tiling is None:
       xla_layout = xc.Layout(self.major_to_minor[::-1])
     else:
@@ -81,7 +81,7 @@ class DeviceLocalLayout:
       sub_byte_size = 0
       xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling,
                              sub_byte_size)
-    return str(xla_layout)
+    return xla_layout
 
   def check_compatible_aval(self, aval_shape: Shape):
     if len(self.major_to_minor) != len(aval_shape):
```
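The practical effect of the two `DeviceLocalLayout` hunks: `_to_xla_layout` now returns an `xc.Layout` object rather than its string form, and callers that need text (such as the MLIR `_to_xla_layout` helper earlier in this diff) stringify it themselves. A hedged sketch; the constructor arguments are illustrative:

```python
import numpy as np
from jax._src.layout import DeviceLocalLayout

dll = DeviceLocalLayout(major_to_minor=(1, 0))
xla_layout = dll._to_xla_layout(np.dtype('float32'))  # now an xc.Layout
as_text = str(xla_layout)  # callers stringify when they need the old string
```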
```diff
@@ -279,11 +279,12 @@ def _get_fastpath_data(
         else s
         for s, a in zip(executable._in_shardings, executable.in_avals)
     ]
+    in_dlls = pxla.get_layouts_for_fasthpath_data(
+        executable._in_layouts, in_shardings, executable.in_avals)
     fastpath_data = pxla.MeshExecutableFastpathData(
         executable.xla_executable, out_tree, in_shardings,
         executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
-        executable.unsafe_call.in_handler.local_devices,
-        executable.unsafe_call.in_handler.input_indices)
+        in_dlls)
   else:
     fastpath_data = None
   return fastpath_data
@@ -302,9 +303,7 @@ def _read_most_recent_pjit_call_executable(jaxpr):
 
 
 def _read_pgle_profiler(jaxpr):
-  return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(
-      jaxpr, None
-  )
+  return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None)
 
 def _cpp_pjit_evict_fn(self):
   self._clear_cache()
@@ -343,8 +342,7 @@ def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
   cpp_pjit_f = xc._xla.pjit(
       fun_name(fun),
       fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
-      jit_info.donate_argnums, tree_util.dispatch_registry,
-      lambda x, sharding: pxla.shard_args([sharding], [x])[0],
+      jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
       _get_cpp_global_cache(jit_info.has_explicit_sharding))
 
   cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
@@ -1729,8 +1727,7 @@ def _pjit_call_impl(*args, jaxpr,
                           in_shardings, out_shardings, None, None)
   return xc._xla.pjit(
       name, f, call_impl_cache_miss, [], [], donated_argnums,
-      tree_util.dispatch_registry,
-      lambda x, sharding: pxla.shard_args([sharding], [x])[0],
+      tree_util.dispatch_registry, pxla.cc_shard_arg,
       _get_cpp_global_cache(has_explicit_sharding))(*args)
 
 pjit_p.def_impl(_pjit_call_impl)
```
```diff
@@ -466,11 +466,12 @@ xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
 xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
 
 
-def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings):
+def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts):
   arrs = [x._base_array for x in xs]
   phys_shardings = [physical_sharding(x.aval, sharding)
                     for x, sharding in zip(xs, shardings)]
-  return pxla.shard_args(phys_shardings, arrs)
+  # TODO(yashkatariya): `layouts` should be converted to physical layouts.
+  return pxla.shard_args(phys_shardings, layouts, arrs)
 
 
 pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler
```
```diff
@@ -3373,7 +3373,7 @@ class FooArray:
   size = property(lambda self: self.data.size // 2)
   ndim = property(lambda self: self.data.ndim - 1)
 
-def shard_foo_array_handler(xs, shardings):
+def shard_foo_array_handler(xs, shardings, layouts):
   results = []
   for x, sharding in safe_zip(xs, shardings):
     device, = sharding._addressable_device_assignment
```
```diff
@@ -500,6 +500,29 @@ class LayoutTest(jtu.JaxTestCase):
         'Layout passed to jit does not match the layout on the respective arg'):
       g(arr)
 
+  def test_in_layouts_jit_jnp_input(self):
+    major_last_layout = DLL(major_to_minor=(1, 0))
+    sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
+
+    f = jax.jit(lambda x: x + 1,
+                in_shardings=Layout(major_last_layout, sharding))
+
+    arr = jnp.arange(8 * 128).reshape(8, 128)
+    out = f(arr)
+    self.assertArraysEqual(out, arr + 1)
+
+    # cpp dispatch should call into shard_args from cpp.
+    out2 = f(arr)
+    self.assertArraysEqual(out2, arr + 1)
+
+    np_inp = np.arange(8 * 128).reshape(8, 128)
+    out3 = f(np_inp)
+    self.assertArraysEqual(out3, np_inp + 1)
+
+    # cpp dispatch should call into shard_args from cpp.
+    out4 = f(np_inp)
+    self.assertArraysEqual(out4, np_inp + 1)
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
```
```diff
@@ -3015,7 +3015,7 @@ class ShardArgsTest(jtu.JaxTestCase):
       x = np.arange(math.prod(shape)).reshape(shape)
       arg = make_arg(x)
       sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec)
-      results = pxla.shard_args([sharding], [arg])
+      results = pxla.shard_args([sharding], [None], [arg])
      self.assertEqual(len(results), 1)
      if isinstance(results[0], array.ArrayImpl):
        bufs = results[0]._arrays
```