Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Delete some dead code that pertained to sharded_jit.
sharded_jit is long gone.

PiperOrigin-RevId: 523711890
parent 3ca7d67e8d
commit 33acdc0e40
@@ -31,7 +31,7 @@ import numpy as np
 
 import jax
 from jax.errors import JAXTypeError
-from jax.tree_util import tree_flatten, tree_map
+from jax.tree_util import tree_map
 
 from jax._src import api_util
 from jax._src import core
@@ -712,7 +712,7 @@ class ParallelCallableInfo:
 
 class ShardInfo(NamedTuple):
   sharded_avals: Sequence[core.AbstractValue]
-  out_sharded_avals: Sequence[core.AbstractValue]
+  out_sharded_avals: Sequence[core.ShapedArray]
   global_sharded_avals: Sequence[core.AbstractValue]
   num_local_shards: int
   num_global_shards: int
@@ -761,25 +761,16 @@ def stage_parallel_callable(
   check_multihost_collective_allowlist(jaxpr)
 
   replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
-  parts = find_partitions(jaxpr)
-
-  num_local_shards = replicas.num_local_replicas * parts.local_num_partitions
-  num_global_shards = replicas.num_global_replicas * parts.num_partitions
+  num_local_shards = replicas.num_local_replicas
+  num_global_shards = replicas.num_global_replicas
 
   shards = ShardInfo(
       sharded_avals, out_sharded_avals, sharded_avals,
       num_local_shards, num_global_shards)
 
-  return jaxpr, consts, replicas, parts, shards
+  return jaxpr, consts, replicas, shards
 
 
-def _shardings_to_mlir_shardings(
-    shardings: Optional[Sequence[PartitionsOrReplicated]]
-) -> Optional[Sequence[Optional[xc.OpSharding]]]:
-  if shardings is None:
-    return None
-  return [xla.sharding_to_proto(s) for s in shardings]
-
 @profiler.annotate_function
 def lower_parallel_callable(
     fun: lu.WrappedFun,
@@ -817,8 +808,7 @@ def lower_parallel_callable(
   if xb.process_count(backend) > 1:
     if devices:
       # This allows each host in a multi-host pmap to run on a different number
-      # of devices, but precludes nested sharding (i.e. inner pmaps or
-      # sharded_jits).
+      # of devices, but precludes nested sharding (i.e. inner pmaps).
       no_nested_sharding = True
     else:
       # This assumes all hosts run on the same number of devices. We make sure
@@ -830,18 +820,12 @@ def lower_parallel_callable(
   pci = ParallelCallableInfo(
       name, backend, axis_name, axis_size, global_axis_size, devices,
       in_axes, out_axes_thunk, avals)
-  jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun)
+  jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
   if logger.isEnabledFor(logging.DEBUG):
     logger.debug("sharded_avals: %s", shards.sharded_avals)
     logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
     logger.debug("num_replicas: %d num_local_replicas: %d",
                  replicas.num_global_replicas, replicas.num_local_replicas)
-    logger.debug("num_partitions: %d local_num_partitions: %d",
-                 parts.num_partitions, parts.local_num_partitions)
-    logger.debug("arg_parts: %s", parts.arg_parts)
-    logger.debug("local_arg_parts: %s", parts.local_arg_parts)
-    logger.debug("out_parts: %s", parts.out_parts)
-    logger.debug("local_out_parts: %s", parts.local_out_parts)
     logger.debug("devices: %s", devices)
     logger.debug("local_devices: %s", pci.local_devices)
 
@@ -858,24 +842,21 @@ def lower_parallel_callable(
         f"On multi-host platforms, pmapped functions must run across all "
         f"devices, i.e. num_replicas * num_partitions should equal the "
         f"number of local devices. Got "
-        f"num_replicas={replicas.num_local_replicas}, "
-        f"num_partitions={parts.num_partitions}, and "
+        f"num_replicas={replicas.num_local_replicas}, and "
        f"num_local_devices={xb.local_device_count(backend)}")
 
-  if no_nested_sharding and (
-      replicas.jaxpr_replicas > 1 or parts.num_partitions > 1):
+  if no_nested_sharding and replicas.jaxpr_replicas > 1:
     raise ValueError(
         f"On multi-host platforms, pmapped functions that both have `devices` "
-        f"specified and contain an inner_pmap or sharded_jit must specify an "
+        f"specified and contain an inner_pmap must specify an "
         f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
-        f"{replicas.jaxpr_replicas} and nested_partitions={parts.num_partitions}")
+        f"{replicas.jaxpr_replicas}")
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
   logger.log(log_priority,
-             "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d"
-             " num_partitions=%d)", fun.__name__, id(fun),
-             shards.num_global_shards, avals, replicas.num_global_replicas,
-             parts.num_partitions)
+             "Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
+             fun.__name__, id(fun),
+             shards.num_global_shards, avals, replicas.num_global_replicas)
 
   axis_env = sharding_impls.AxisEnv(
       replicas.num_global_replicas, (axis_name,), (global_axis_size,))
@@ -903,14 +884,14 @@ def lower_parallel_callable(
         name_stack,
         donated_invars,
         replicated_args=replicated_args,
-        arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
-        result_shardings=_shardings_to_mlir_shardings(parts.out_parts),
+        arg_shardings=None,
+        result_shardings=None,
         arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
         result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
   module, keepalive, host_callbacks = (
       lowering_result.module, lowering_result.keepalive,
       lowering_result.host_callbacks)
-  return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
+  return PmapComputation(module, pci=pci, replicas=replicas,
                          shards=shards, tuple_args=tuple_args,
                          unordered_effects=unordered_effects,
                          ordered_effects=ordered_effects,
@@ -951,6 +932,9 @@ class PmapComputation(stages.XlaLowering):
       return executable
     return self._executable
 
+def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray:
+  assert isinstance(aval, ShapedArray), aval
+  return cast(ShapedArray, aval)
 
 @dataclasses.dataclass
 class UnloadedPmapExecutable:
@@ -969,7 +953,6 @@ class UnloadedPmapExecutable:
   def from_hlo(xla_computation,
                pci: ParallelCallableInfo,
                replicas: ReplicaInfo,
-               parts: PartitionInfo,
               shards: ShardInfo,
               tuple_args: bool,
               unordered_effects: List[core.Effect],
@@ -981,11 +964,10 @@ class UnloadedPmapExecutable:
     if devices is None:
       if shards.num_global_shards > xb.device_count(pci.backend):
         msg = ("compiling computation that requires {} logical devices, but only {} XLA "
-               "devices are available (num_replicas={}, num_partitions={})")
+               "devices are available (num_replicas={})")
         raise ValueError(msg.format(shards.num_global_shards,
                                     xb.device_count(pci.backend),
-                                    replicas.num_global_replicas,
-                                    parts.num_partitions))
+                                    replicas.num_global_replicas))
       # On a single host, we simply grab the first N devices from jax.devices().
       # In the single host case, we want the default device order of pmap to
       # match jax.devices().
@@ -1023,16 +1005,14 @@ class UnloadedPmapExecutable:
     # get_default_device_assignment() returns 2D assignment, caller may have
     # provided 1D list of devices).
     # Convert to 2D in case it's 1D and we have > 1 partitions.
+    num_partitions = 1
     device_assignment: np.ndarray = np.array(devices).reshape(
-        (replicas.num_global_replicas, parts.num_partitions))
-    # TODO(b/162356737): Enabling SPMD partitioning causes issues with some
-    # non-partitioned workloads, so disable unless needed.
-    use_spmd_partitioning = parts.num_partitions > 1
+        (replicas.num_global_replicas, num_partitions))
     compile_options = xb.get_compile_options(
         num_replicas=replicas.num_global_replicas,
-        num_partitions=parts.num_partitions,
+        num_partitions=num_partitions,
         device_assignment=device_assignment,
-        use_spmd_partitioning=use_spmd_partitioning,
+        use_spmd_partitioning=False,
         env_options_overrides=compiler_options,
     )
     compile_options.parameter_is_tupled_arguments = tuple_args
@@ -1042,34 +1022,24 @@ class UnloadedPmapExecutable:
         d for d in device_assignment.flat if d.process_index == process_index
     ])
 
-    local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
     input_sharding_specs = [
         sharding_specs.pmap_sharding_spec(
             replicas.num_local_replicas, pci.axis_size,
-            parts.local_num_partitions, arg_parts,
             cast(ShapedArray, aval).shape, in_axis)
-        for aval, arg_parts, in_axis in safe_zip(
-            shards.sharded_avals, local_arg_parts_, pci.in_axes)]
-    in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs)
-    nouts = len(shards.out_sharded_avals)
+        for aval, in_axis in safe_zip(shards.sharded_avals, pci.in_axes)]
+    in_shardings = _get_pmap_sharding(local_device_assignment,
+                                      input_sharding_specs)
 
-    out_parts = (None,) * nouts if parts.out_parts is None else parts.out_parts
-    local_out_parts = (None,) * nouts if parts.local_out_parts is None else parts.local_out_parts
-
-    local_out_avals = [
-        get_local_aval(aval, parts, lparts)
-        for aval, parts, lparts
-        in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
     local_unmapped_avals = [
-        core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
+        _cast_to_shaped_array(
+            core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
         if out_axis is not None else aval
-        for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
+        for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
     out_specs = [
         sharding_specs.pmap_sharding_spec(
-            replicas.num_local_replicas, pci.axis_size,
-            parts.local_num_partitions, out_parts, aval.shape, out_axis)
-        for out_parts, aval, out_axis in safe_zip(
-            local_out_parts, local_out_avals, pci.out_axes)]
+            replicas.num_local_replicas, pci.axis_size, aval.shape, out_axis)
+        for aval, out_axis in safe_zip(
+            shards.out_sharded_avals, pci.out_axes)]
     out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
 
     if hasattr(pci.backend, "compile_replicated"):
@@ -1193,124 +1163,6 @@ def check_multihost_collective_allowlist(jaxpr):
     raise TypeError(msg.format(", ".join(map(str, bad_collectives))))
 
 
-PartitionsOrReplicated = Optional[Tuple[int, ...]]
-
-class PartitionInfo(NamedTuple):
-  arg_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
-  out_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
-  num_partitions: int
-  local_arg_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
-  local_out_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
-  local_num_partitions: Optional[int]
-
-def _find_partitions(jaxpr):
-  """Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
-  local_out_parts, local_num_partitions).
-  """
-  for eqn in jaxpr.eqns:
-    if eqn.primitive.name == "sharded_call":
-      if len(jaxpr.eqns) > 1:
-        raise NotImplementedError(
-            "pmap of sharded_jit + non-sharded operations not yet implemented.")
-      num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
-                                                eqn.params["nparts"])
-      return (eqn.params["in_parts"],
-              eqn.params["out_parts_thunk"](),
-              num_partitions,
-              eqn.params["local_in_parts"],
-              eqn.params["local_out_parts_thunk"](),
-              eqn.params["local_nparts"])
-  return None, None, 1, None, None, None
-
-def find_partitions(jaxpr) -> PartitionInfo:
-  (arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
-   local_num_partitions) = _find_partitions(jaxpr)
-
-  if local_num_partitions is None:
-    local_num_partitions = num_partitions
-  if local_arg_parts is None:
-    local_arg_parts = arg_parts
-  if local_out_parts is None:
-    local_out_parts = out_parts
-
-  return PartitionInfo(arg_parts, out_parts, num_partitions,
-                       local_arg_parts, local_out_parts, local_num_partitions)
-
-
-def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
-  """Returns the total number of partitions to use.
-
-  Validates that any inner partitioning matches outer_num_parts if provided, and
-  returns the number of partitions to use based on outer_num_parts and any inner
-  partitioning.
-  """
-  inner_num_parts = _inner_partitions(jaxpr, outer_num_parts)
-  if outer_num_parts is None and inner_num_parts is None:
-    # No partitions specified anywhere, everything is replicated.
-    return 1
-  if outer_num_parts is None:
-    return inner_num_parts
-  return outer_num_parts
-
-
-def _inner_partitions(jaxpr, expected_num_parts: Optional[int]):
-  """Returns the total number of partitions from PartitionSpecs inside `jaxpr`.
-
-  Also validates that this number matches `expected_num_parts` if provided.
-  """
-  for eqn in jaxpr.eqns:
-    if eqn.primitive.name in ["sharding_constraint", "infeed"]:
-      parts = eqn.params["partitions"]
-      nparts = get_num_partitions(parts)
-      if expected_num_parts is None:
-        expected_num_parts = nparts
-      elif nparts is not None and nparts != expected_num_parts:
-        # TODO(skye): raise this error as we trace the jaxpr
-        raise ValueError(
-            f"with_sharding_constraint with partitions={parts} "
-            f"(total partitions: {nparts}) doesn't match expected number of "
-            f"partitions: {expected_num_parts}. If these partitions look "
-            f"right, check outer sharded_jit and/or other "
-            f"with_sharding_constraint calls.")
-    else:
-      for subjaxpr in core.jaxprs_in_params(eqn.params):
-        expected_num_parts = _inner_partitions(subjaxpr, expected_num_parts)
-  return expected_num_parts
-
-
-def get_num_partitions(*partitions):
-  partition_specs = tree_flatten(partitions)[0]
-  if len(partition_specs) == 0:
-    # Everything is specified as replicated (all Nones).
-    return None
-  num_partitions_set = {np.prod(spec) for spec in partition_specs}
-  if len(num_partitions_set) > 1:
-    raise ValueError(
-        f"All partition specs must use the same number of total partitions, "
-        f"got {partitions}, with distinct number of partitions "
-        f"{num_partitions_set} (the total number of partitions is the product "
-        f"of a partition spec)")
-  assert len(num_partitions_set) == 1
-  return num_partitions_set.pop()
-
-
-def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
-                   local_parts: PartitionsOrReplicated):
-  if global_parts is None:
-    return global_aval
-  assert local_parts is not None
-  local_shape = [_safe_div(dim, _safe_div(ngparts, nlparts))
-                 for dim, ngparts, nlparts
-                 in safe_zip(global_aval.shape, global_parts, local_parts)]
-  return global_aval.update(shape=local_shape)
-
-
-def _safe_div(x, y):
-  result, ragged = divmod(x, y)
-  assert not ragged, f"{x} % {y} != 0"
-  return result
-
-
 class InputsHandler:
   __slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
 
@@ -1413,7 +1265,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
   replicated_aval = aval
   # TODO(skye): figure out how partitioning should work here
   sharding_spec = sharding_specs.pmap_sharding_spec(
-      nrep, axis_size, 1, None, aval.shape, in_axis)
+      nrep, axis_size, aval.shape, in_axis)
 
   buf = jax.device_put(val, devices[0])
   sharding = sharding_impls.PmapSharding(
@@ -3144,7 +2996,6 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
 
 _forbidden_primitives = {
   'xla_pmap': 'pmap',
-  'sharded_call': 'sharded_jit',
 }
 def _sanitize_mesh_jaxpr(jaxpr):
   if isinstance(jaxpr, core.ClosedJaxpr):
@@ -3246,5 +3097,7 @@ def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
 
 def _pmap_sharding_spec(nrep, axis_size, npart, parts,
                         sharded_aval, map_axis: Optional[int]) -> ShardingSpec:
-  return sharding_specs.pmap_sharding_spec(nrep, axis_size, npart, parts,
-                                           sharded_aval.shape, map_axis)
+  assert npart == 1, npart
+  assert parts is None, parts
+  return sharding_specs.pmap_sharding_spec(
+      nrep, axis_size, sharded_aval.shape, map_axis)
@@ -254,22 +254,6 @@ def spec_to_indices(shape: Sequence[int],
   return tuple(spec.indices(shape).flat)  # type: ignore
 
 
-def partitioned_sharding_spec(num_partitions: int,
-                              partitions: Optional[Sequence[int]],
-                              shape: Sequence[int]) -> ShardingSpec:
-  if partitions is None:
-    maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
-    return ShardingSpec(
-        sharding=[_UNSHARDED_INSTANCE] * len(shape),
-        mesh_mapping=maybe_replicate)
-  else:
-    assert len(partitions) == len(shape)
-    return ShardingSpec(
-        # Chunked expects a list of integers
-        sharding=map(Chunked, [[x] for x in partitions]),
-        mesh_mapping=map(ShardedAxis, range(len(partitions))))
-
-
 def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
   mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
   sharding = [_UNSHARDED_INSTANCE] * num_dimensions
@@ -292,15 +276,12 @@ def new_mesh_sharding_specs(axis_sizes, axis_names):
   mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
   return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos)
 
-def pmap_sharding_spec(nrep, axis_size, npart, parts,
-                       sharded_shape: Sequence[int],
+def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
                        map_axis: Optional[int]) -> ShardingSpec:
   """Sharding spec for arguments or results of a pmap.
   Args:
     nrep: number of local XLA replicas (product of local axis sizes)
     axis_size: local axis size for outer pmap
-    npart: total number of XLA partitions (required by sharded_jit calls)
-    parts: the partitioning of the value or None
     sharded_aval: the aval of the value inside the outer pmap, an instance of
       a ShapedArray.
     map_axis: the axis along which the value is mapped in the outer pmap
@@ -309,8 +290,8 @@ def pmap_sharding_spec(nrep, axis_size, npart, parts,
   """
   replication_factor, ragged = divmod(nrep, axis_size)
   assert not ragged
-  # get the sharding spec from inner sharded_jits as if we weren't in a pmap
-  pspec = partitioned_sharding_spec(npart, parts, sharded_shape)
+  pspec = ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(sharded_shape),
+                       mesh_mapping=())
   maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
   if map_axis is not None:
     sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
@@ -342,5 +323,5 @@ def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
     assert sharded_dim_size is not None
     sharded_shape = shape
 
-  return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
-                            sharded_shape, sharded_dim)
+  return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, sharded_shape,
+                            sharded_dim)
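The hunks above narrow pmap_sharding_spec to four arguments (nrep, axis_size, sharded_shape, map_axis). As a concrete illustration, not part of the commit, the sketch below exercises the simplified helper; the module path jax._src.sharding_specs is inferred from the sharding_specs.pmap_sharding_spec call sites in pxla.py above, and the sizes are made up.

# Illustrative sketch, not from the commit: calls the new four-argument
# pmap_sharding_spec shown in the diff. The import path and the example sizes
# (8 local replicas, a (4, 3) per-device value mapped over axis 0) are assumptions.
from jax._src import sharding_specs

nrep, axis_size = 8, 8    # local XLA replicas; equals the pmap axis size here
sharded_shape = (4, 3)    # shape of the value as seen inside the pmap
spec = sharding_specs.pmap_sharding_spec(nrep, axis_size, sharded_shape, 0)
print(spec)  # a ShardingSpec describing how the mapped value is laid out across devices

With partitions gone, the spec depends only on the replica count, the pmap axis size, the per-device shape, and the mapped axis, which is what lets the npart/parts plumbing be deleted throughout this commit.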
@@ -3063,8 +3063,10 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
 tf_impl[ad.custom_lin_p] = _custom_lin
 
 
+PartitionsOrReplicated = Optional[Tuple[int, ...]]
+
 def split_to_logical_devices(tensor: TfVal,
-                             partition_dimensions: pxla.PartitionsOrReplicated):
+                             partition_dimensions: PartitionsOrReplicated):
   """Like TPUMPStrategy.experimental_split_to_logical_devices.
 
   For jax2tf purposes we want to avoid needing to thread the `strategy` object
@@ -192,8 +192,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
     for p in all_primitives:
       if p.name == "axis_index":
         continue
-      # TODO: remove once we delete sharded_jit.py
-      if p.name in ["sharded_call", "sharding_constraint"]:
+      if p.name == "sharding_constraint":
         continue
       # TODO: Remove once tensorflow is 2.10.0 everywhere.
       if p.name == "optimization_barrier":
@@ -24,8 +24,6 @@ from jax._src.interpreters.pxla import (
   MeshDimAssignment as MeshDimAssignment,
   MeshExecutable as MeshExecutable,
   ParallelCallableInfo as ParallelCallableInfo,
-  PartitionInfo as PartitionInfo,
-  PartitionsOrReplicated as PartitionsOrReplicated,
   PmapComputation as PmapComputation,
   PmapExecutable as PmapExecutable,
   PxlaResultHandler as PxlaResultHandler,
@@ -45,11 +43,8 @@ from jax._src.interpreters.pxla import (
   array_types as array_types,
   custom_resource_typing_rules as custom_resource_typing_rules,
   device_put as _deprecated_device_put,
-  find_partitions as find_partitions,
   find_replicas as find_replicas,
   full_to_shard_p as full_to_shard_p,
-  get_local_aval as get_local_aval,
-  get_num_partitions as get_num_partitions,
   global_aval_to_result_handler as global_aval_to_result_handler,
   global_avals_to_results_handler as global_avals_to_results_handler,
   global_result_handlers as global_result_handlers,
@@ -63,7 +58,6 @@ from jax._src.interpreters.pxla import (
   mesh_sharding_specs as mesh_sharding_specs,
   multi_host_supported_collectives as multi_host_supported_collectives,
   parallel_callable as parallel_callable,
-  reconcile_num_partitions as reconcile_num_partitions,
   replicate as replicate,
   resource_typecheck as resource_typecheck,
   shard_arg_handlers as shard_arg_handlers,