Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Rename `jax.sharding.OpShardingSharding` to `jax.sharding.GSPMDSharding`.
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.

PiperOrigin-RevId: 510556189
This commit is contained in:
parent
f7734fd6a4
commit
0ffdeb3de2
@@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
 
 ## jax 0.4.5
 
+* Deprecations
+  * `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
+    `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
+
 ## jaxlib 0.4.5
 
 ## jax 0.4.4 (Feb 16, 2023)
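As a quick illustration of the deprecation entry above, here is a minimal migration sketch (assuming the jax/jaxlib 0.4.5 public API; nothing beyond the names introduced in this commit is relied on):

```python
# Minimal sketch of the rename, assuming jax/jaxlib 0.4.5.
# GSPMDSharding is the new name; OpShardingSharding stays importable as a
# deprecated alias until its removal (3 months from Feb 17, 2023).
import jax
from jax.sharding import GSPMDSharding, OpShardingSharding

# Build a replicated sharding over all local devices using the new name.
s = GSPMDSharding.get_replicated(jax.devices())

# Code written against the old name keeps working for now, because the old
# name is just an alias of the new class (see the alias added further down).
assert OpShardingSharding is GSPMDSharding
assert isinstance(s, OpShardingSharding)
```

Prefer the new name in new code; the alias exists only to ease migration during the deprecation window.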
@@ -655,7 +655,7 @@ with mesh:
     c:bf16[32,512,512] d:bf16[512,512] = pjit[
       donated_invars=(False, False)
       in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>, <_PositionalSemantics.GLOBAL: 1>)
-      in_shardings=(OpShardingSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), OpShardingSharding({replicated}))
+      in_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated}))
       jaxpr={ lambda ; e:bf16[32,512,512] f:bf16[512,512]. let
           g:bf16[8,4,512,512] = reshape[
             dimensions=None

@@ -734,7 +734,7 @@ with mesh:
         in (bn, bf) }
       name=<unnamed function>
       out_positional_semantics=_PositionalSemantics.GLOBAL
-      out_shardings=(OpShardingSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), OpShardingSharding({replicated}))
+      out_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated}))
       resource_env=ResourceEnv(Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('x',)), ())
     ] a b
   in (c, d) }
@@ -43,7 +43,7 @@ from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.lax import control_flow as cf
-from jax._src.sharding import OpShardingSharding
+from jax._src.sharding import GSPMDSharding
 from jax._src.typing import Array
 from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
                            unzip3, weakref_lru_cache)

@@ -860,7 +860,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   if jax.config.jax_array:
     sharding = pjit._UNSPECIFIED
   else:
-    sharding = OpShardingSharding.get_replicated(
+    sharding = GSPMDSharding.get_replicated(
         list(resource_env.physical_mesh.devices.flat))
 
   new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
@@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
+from jax._src.sharding import Sharding, GSPMDSharding, NamedSharding
 
 # pytype: disable=import-error
 try:

@@ -310,7 +310,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
 
   def _op_sharding_callback(op_sharding: xc.OpSharding):
     if mesh.empty:
-      return callback(OpShardingSharding(
+      return callback(GSPMDSharding(
          devices, op_sharding))
     pspec = pjit.parse_flatten_op_sharding(
         op_sharding, mesh)[0].get_partition_spec()
@@ -61,7 +61,7 @@ from jax._src.lib import xla_bridge as xb
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
-                               OpShardingSharding, NamedSharding, PartitionSpec,
+                               GSPMDSharding, NamedSharding, PartitionSpec,
                                Sharding)
 from jax._src.util import flatten, unflatten
 

@@ -322,7 +322,7 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
   assert len(da) == 1
   # in_shardings will be marked as replicated regardless of whatever the input
   # had. Given that only a single device is allowed above, this is correct.
-  in_shardings = [OpShardingSharding.get_replicated(da)] * num_ins
+  in_shardings = [GSPMDSharding.get_replicated(da)] * num_ins
   return da, in_shardings
 
 
@@ -2896,7 +2896,7 @@ def lower_sharding_computation(
       any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or  # type: ignore
       any(not _is_unspecified(o) for o in out_shardings))  # type: ignore
 
-  in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
+  in_shardings = tuple(sharding_internal.GSPMDSharding.get_replicated(device_assignment)
                        if _is_unspecified(i) else i for i in in_shardings)
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG

@@ -3372,9 +3372,9 @@ def get_op_sharding_shardings_from_executable(
 
   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
 
-  in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
+  in_shardings_xla = [sharding_internal.GSPMDSharding(device_assignment, i)
                       for i in in_op_shardings]
-  out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
+  out_shardings_xla = [sharding_internal.GSPMDSharding(device_assignment, o)
                        for o in out_op_shardings]
   # This condition happens when all the elements in the output tuple have the
   # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to

@@ -3720,7 +3720,7 @@ def _out_shardings_for_trivial(
   # a replicated sharding
   from jax._src import array
 
-  rep = sharding_internal.OpShardingSharding(
+  rep = sharding_internal.GSPMDSharding(
       device_assignment, sharding_internal._get_replicated_op_sharding())
   shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
   for constvar, constval in zip(jaxpr.constvars, consts):
@@ -1860,7 +1860,7 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):
 def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
   from jax._src.pjit import (
       sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
-      OpShardingSharding)
+      GSPMDSharding)
 
   rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
   if isinstance(jaxpr, core.ClosedJaxpr):

@@ -1878,7 +1878,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
       mps = NamedSharding._from_parsed_pspec(
           resource_env.physical_mesh, ParsedPartitionSpec((), ()))
       unconstrained_dims = get_unconstrained_dims(mps)
-      op_sharding_sharding = OpShardingSharding.get_replicated(
+      op_sharding_sharding = GSPMDSharding.get_replicated(
           mps._device_assignment)
       new_eqns.append(core.JaxprEqn(
           [tmpvar], [outvar], sharding_constraint_p,
@@ -37,7 +37,7 @@ from jax.tree_util import (
     treedef_tuple)
 
 from jax._src.sharding import (
-    NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
+    NamedSharding, Sharding, XLACompatibleSharding, GSPMDSharding,
     XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
 from jax._src import array
 from jax._src import dispatch

@@ -90,8 +90,8 @@ def _is_unspecified_or_from_gda_or_auto(x):
   return _is_from_gda(x) or is_auto(x) or _is_unspecified(x)
 
 
-PjitSharding = Union[OpShardingSharding, _UnspecifiedValue, _AUTOAxisResource]
-PjitShardingMinusUnspecified = Union[OpShardingSharding, _AUTOAxisResource]
+PjitSharding = Union[GSPMDSharding, _UnspecifiedValue, _AUTOAxisResource]
+PjitShardingMinusUnspecified = Union[GSPMDSharding, _AUTOAxisResource]
 MeshSharding = Union[NamedSharding, _UnspecifiedValue, _AUTOAxisResource]
 MeshShardingMinusUnspecified = Union[NamedSharding, _AUTOAxisResource]
 

@@ -535,7 +535,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
     in_positional_semantics = (
         pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
 
-  # in_shardings and out_shardings here are all OpShardingSharding.
+  # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
       jaxpr=jaxpr,
       in_shardings=canonicalized_in_shardings_flat,

@@ -1353,7 +1353,7 @@ class SameDeviceAssignmentTuple:
   device_assignment: Optional[XLADeviceAssignment]
 
   def __hash__(self):
-    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s  # type: ignore
+    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, GSPMDSharding) else s  # type: ignore
                            for s in self.shardings)
     if self.device_assignment is None:
       return hash(shardings_hash)

@@ -1364,7 +1364,7 @@ class SameDeviceAssignmentTuple:
     if not isinstance(other, SameDeviceAssignmentTuple):
       return False
     return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)  # pytype: disable=attribute-error
-                if isinstance(s, OpShardingSharding) and isinstance(o, OpShardingSharding)
+                if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
                 else s == o
                 for s, o in safe_zip(self.shardings, other.shardings)) and
             self.device_assignment == other.device_assignment)
@@ -1420,14 +1420,14 @@ def _pjit_lower_cached(
         Tuple[MeshShardingMinusUnspecified, ...], tuple(
             NamedSharding._from_parsed_pspec(
                 mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])  # type: ignore
-            if isinstance(i, OpShardingSharding) else i
+            if isinstance(i, GSPMDSharding) else i
             for i in in_shardings
     ))
     out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
         Tuple[MeshSharding, ...], tuple(
             NamedSharding._from_parsed_pspec(
                 mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])  # type: ignore
-            if isinstance(o, OpShardingSharding) else o
+            if isinstance(o, GSPMDSharding) else o
             for o in out_shardings
     ))
 

@@ -1560,7 +1560,7 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
 pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
 
 def _pjit_batcher_for_sharding(
-    s: Union[OpShardingSharding, _UnspecifiedValue],
+    s: Union[GSPMDSharding, _UnspecifiedValue],
     dim: int, val: Tuple[str, ...], mesh, ndim: int):
   if _is_unspecified(s):
     return s

@@ -1569,14 +1569,14 @@ def _pjit_batcher_for_sharding(
     tad = list(new_op.tile_assignment_dimensions)
     tad.insert(dim, 1)
     new_op.tile_assignment_dimensions = tad
-    return OpShardingSharding(s._device_assignment, new_op)  # type: ignore
+    return GSPMDSharding(s._device_assignment, new_op)  # type: ignore
   else:
-    assert isinstance(s, OpShardingSharding)
+    assert isinstance(s, GSPMDSharding)
     assert mesh is not None and not mesh.empty
     parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]  # type: ignore
     parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
     mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
-    return OpShardingSharding(mps._device_assignment, mps._to_xla_op_sharding(ndim))
+    return GSPMDSharding(mps._device_assignment, mps._to_xla_op_sharding(ndim))
 
 
 def _pjit_jvp(primals_in, tangents_in,
@@ -1642,7 +1642,7 @@ def _pjit_partial_eval(trace, *in_tracers,
     residual_shardings = (_UNSPECIFIED,) * num_residuals
   else:
     da = list(resource_env.physical_mesh.devices.flat)
-    residual_shardings = (OpShardingSharding.get_replicated(da),) * num_residuals
+    residual_shardings = (GSPMDSharding.get_replicated(da),) * num_residuals
   # Compute the known outputs
   known_params = dict(
       jaxpr=known_jaxpr,

@@ -1683,7 +1683,7 @@ def _pjit_partial_eval(trace, *in_tracers,
      residual_op_shardings = ()
    assert len(residual_shardings) == len(residual_op_shardings), (
        len(residual_shardings), len(residual_op_shardings))
-   residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
+   residual_shardings = tuple(GSPMDSharding(da, op) for op in residual_op_shardings)
    known_params['out_shardings'] = (
        keep_where(out_shardings, known_outs) + residual_shardings)
 

@@ -2080,13 +2080,13 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
   aval, = ctx.avals_in
   axis_ctx = ctx.module_context.axis_context
   # axis_ctx and manual_axes is *only used with xmap* and xmap only works with
-  # NamedSharding. So convert the OpShardingSharding to NamedSharding
+  # NamedSharding. So convert the GSPMDSharding to NamedSharding
   # and then convert it back with the added special axes.
   if isinstance(axis_ctx, mlir.SPMDAxisContext):
     mesh = resource_env.physical_mesh
     parsed_pspec = parse_flatten_op_sharding(sharding._op_sharding, mesh)[0]
     mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
-    sharding = OpShardingSharding(
+    sharding = GSPMDSharding(
         mps._device_assignment, mps._to_xla_op_sharding(aval.ndim, axis_ctx=axis_ctx))
   return [
       mlir.wrap_with_sharding_op(

@@ -2152,10 +2152,10 @@ def get_array_mapping(
       if axes is not None for axis in axes)
 
 
-def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> OpShardingSharding:
-  if isinstance(s, OpShardingSharding):
+def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
+  if isinstance(s, GSPMDSharding):
     return s
-  op_sharding_sharding = OpShardingSharding(
+  op_sharding_sharding = GSPMDSharding(
       s._device_assignment, s._to_xla_op_sharding(ndim))
   op_sharding_sharding._original_sharding = s
   return op_sharding_sharding

@@ -2232,7 +2232,7 @@ def _maybe_replace_from_gda_with_pspec(
 
 @lru_cache()
 def _gda_check_and_get_sharding(
-    gda_sharding: NamedSharding, in_sharding: OpShardingSharding, ndim: int):
+    gda_sharding: NamedSharding, in_sharding: GSPMDSharding, ndim: int):
   if not _is_from_gda(in_sharding) and not pxla.are_op_shardings_equal(
       gda_sharding._to_xla_op_sharding(ndim),
       in_sharding._to_xla_op_sharding(ndim)):
@@ -43,7 +43,7 @@ from jax._src.lib import gpu_prng
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy import lax_numpy
 from jax._src.sharding import (
-    NamedSharding, PmapSharding, OpShardingSharding)
+    NamedSharding, PmapSharding, GSPMDSharding)
 from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
 
 map, unsafe_map = safe_map, map

@@ -375,7 +375,7 @@ class KeyTyRules:
     if is_out_sharding_from_xla:
       phys_sharding = out_sharding
     else:
-      phys_sharding = OpShardingSharding(
+      phys_sharding = GSPMDSharding(
           out_sharding._device_assignment,
           KeyTyRules.physical_op_sharding(aval, out_sharding))
 
@@ -24,6 +24,7 @@ import jax
 from jax._src import core
 from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension_version
 from jax.interpreters import mlir
 from jax._src.interpreters import pxla
 

@@ -75,10 +76,10 @@ class Sharding(metaclass=abc.ABCMeta):
     """Returns True if two shardings put the same logical array
     (sharded/unsharded) on the same device(s).
 
-    For example, every XLACompatibleSharding lowers to OpShardingSharding which
+    For example, every XLACompatibleSharding lowers to GSPMDSharding which
     is a general representation. So `jax.sharding.NamedSharding` is equivalent
     to `jax.sharding.PositionalSharding` if both of them lower to the same
-    OpShardingSharding.
+    GSPMDSharding.
     """
     raise NotImplementedError('Subclasses should implement this method.')
 

@@ -136,7 +137,7 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
     op_sharding = self._to_xla_op_sharding(len(global_shape))
-    op_sharding_sharding = OpShardingSharding(self._device_assignment,
+    op_sharding_sharding = GSPMDSharding(self._device_assignment,
                                               op_sharding)
     return op_sharding_sharding.devices_indices_map(global_shape)
 

@@ -645,8 +646,8 @@ class DeviceIdSet:
             self._ids == other._ids)
 
 
-@use_cpp_class(xc.OpShardingSharding)
-class OpShardingSharding(XLACompatibleSharding):
+@use_cpp_class(xc.GSPMDSharding if xla_extension_version >= 129 else xc.OpShardingSharding)  # type: ignore
+class GSPMDSharding(XLACompatibleSharding):
 
   @use_cpp_method()
   def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):

@@ -661,7 +662,7 @@ class OpShardingSharding(XLACompatibleSharding):
     return hash(xc.HloSharding.from_proto(self._op_sharding))
 
   def __eq__(self, other):
-    if not isinstance(other, OpShardingSharding):
+    if not isinstance(other, GSPMDSharding):
       return False
     if id(self) == id(other):
       return True

@@ -674,7 +675,7 @@ class OpShardingSharding(XLACompatibleSharding):
     return self._hash
 
   def __repr__(self):
-    return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
+    return f'GSPMDSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
 
   def is_compatible_aval(self, aval_shape: Shape):
     num_ways_dim_sharded, _ = pxla.get_num_ways_dim_sharded(self._op_sharding)

@@ -705,3 +706,8 @@ class OpShardingSharding(XLACompatibleSharding):
   def get_replicated(cls, device_assignment):
     proto = _get_replicated_op_sharding()
     return cls(device_assignment, proto)
+
+
+# TODO(yashkatariya); Remove OpShardingSharding after 3 months from Feb 17, 2023
+# per the deprecation policy.
+OpShardingSharding = GSPMDSharding
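To make the renamed class's behavior concrete, here is a hedged sketch of typical `GSPMDSharding` usage (assuming the jax 0.4.5-era API shown in the hunks above; the `(8, 2)` shape is only illustrative):

```python
# Hedged sketch: equality and devices_indices_map on replicated GSPMDShardings.
# Assumes the jax 0.4.5-era API; shapes and device counts are illustrative.
import jax
from jax.sharding import GSPMDSharding

devices = jax.devices()
s1 = GSPMDSharding.get_replicated(devices)
s2 = GSPMDSharding.get_replicated(devices)

# Two GSPMDShardings over the same devices with equivalent OpSharding protos
# compare equal, so independently built replicated shardings are ==.
assert s1 == s2

# devices_indices_map is a cached method (exercised by the cache_info tests
# further down); for a replicated sharding every device maps to the full array.
print(s1.devices_indices_map((8, 2)))
```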
@@ -385,8 +385,8 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
 
   def to_mesh_pspec_sharding(op_sharding: xc.OpSharding):
     if mesh.empty:
-      from jax._src.sharding import OpShardingSharding
-      return OpShardingSharding(devices, op_sharding)
+      from jax._src.sharding import GSPMDSharding
+      return GSPMDSharding(devices, op_sharding)
     pspec = pjit.parse_flatten_op_sharding(op_sharding,
                                            mesh)[0].get_partition_spec()
     return jax.sharding.NamedSharding(mesh, pspec)
@@ -22,7 +22,7 @@ from jax._src import util
 from jax._src import config as jax_config
 from jax.config import config
 from jax._src import array
-from jax._src.sharding import NamedSharding, OpShardingSharding
+from jax._src.sharding import NamedSharding, GSPMDSharding
 from jax.sharding import PartitionSpec as P
 from jax.experimental.gda_serialization import serialization
 import numpy as np

@@ -133,7 +133,7 @@ class CheckpointTest(jtu.JaxTestCase):
     for l in m1.addressable_shards:
       self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
 
-    new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
+    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
     m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
     for l in m2.addressable_shards:
       self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
@@ -96,7 +96,7 @@ def _identity_fn(x):
 
 def _handle_array_process_allgather(inp, tiled):
   if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
-    reps = sharding.OpShardingSharding(inp.sharding._device_assignment,
+    reps = sharding.GSPMDSharding(inp.sharding._device_assignment,
                                        sharding._get_replicated_op_sharding())
     out = pjit(_identity_fn, out_axis_resources=reps)(inp)
   else:
@@ -18,13 +18,16 @@
 from jax._src.sharding import (
     Sharding as Sharding,
     XLACompatibleSharding as XLACompatibleSharding,
-    # TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
+    # TODO(yashkatariya): Remove MeshPspecSharding in 3 months.
     MeshPspecSharding as MeshPspecSharding,
     # New name of MeshPspecSharding to match PositionalSharding below.
     NamedSharding as NamedSharding,
     PartitionSpec as PartitionSpec,
     SingleDeviceSharding as SingleDeviceSharding,
     PmapSharding as PmapSharding,
+    GSPMDSharding as GSPMDSharding,
+    # TODO(yashkatariya): Remove OpShardingSharding in 3 months from
+    # Feb 17, 2023.
     OpShardingSharding as OpShardingSharding,
     PositionalSharding as PositionalSharding,
 )
@@ -694,7 +694,7 @@ class ShardingTest(jtu.JaxTestCase):
     shape = (8, 4)
     mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     mps = sharding.NamedSharding(mesh, pspec)
-    ops = sharding.OpShardingSharding(
+    ops = sharding.GSPMDSharding(
         list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape)))
     self.assertDictEqual(
         ops.devices_indices_map(shape), mps.devices_indices_map(shape))

@@ -771,16 +771,16 @@ class ShardingTest(jtu.JaxTestCase):
     op.tile_assignment_dimensions = [4, 1, 2]
     op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
     op.replicate_on_last_tile_dim = True
-    s = sharding.OpShardingSharding(jax.devices(), op)
+    s = sharding.GSPMDSharding(jax.devices(), op)
     self.assertEqual(
         repr(s),
-        'OpShardingSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
+        'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
         'last_tile_dim_replicate})')
 
     op2 = xc.OpSharding()
     op2.type = xc.OpSharding.Type.REPLICATED
-    s2 = sharding.OpShardingSharding(jax.devices(), op2)
-    self.assertEqual(repr(s2), 'OpShardingSharding({replicated})')
+    s2 = sharding.GSPMDSharding(jax.devices(), op2)
+    self.assertEqual(repr(s2), 'GSPMDSharding({replicated})')
 
   @parameterized.named_parameters(
       ("mesh_x_y", P("x", "y"), (4, 2), (), False),
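For comparison with the repr assertions above, a small standalone sketch that builds a `GSPMDSharding` directly from an `xc.OpSharding` proto (assuming jax 0.4.5-era internals; `jax._src.lib.xla_client` is a private module and the exact repr text may differ across versions):

```python
# Sketch mirroring the test above: a GSPMDSharding built from a REPLICATED
# OpSharding proto. Uses private jax internals (jax._src.lib.xla_client),
# so treat this as illustrative rather than supported API.
import jax
from jax._src.lib import xla_client as xc

op = xc.OpSharding()
op.type = xc.OpSharding.Type.REPLICATED

s = jax.sharding.GSPMDSharding(jax.devices(), op)
print(repr(s))  # expected to print something like: GSPMDSharding({replicated})
```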
@@ -876,9 +876,9 @@ class ShardingTest(jtu.JaxTestCase):
 
     op1 = xc.OpSharding()
     op1.type = xc.OpSharding.Type.REPLICATED
-    s6 = jax.sharding.OpShardingSharding([jax.devices()[0]], op1)
+    s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1)
 
-    s7 = jax.sharding.OpShardingSharding(jax.devices(), op1)
+    s7 = jax.sharding.GSPMDSharding(jax.devices(), op1)
 
     # The OpSharding is replicated but the Sharding itself are on different
     # devices.

@@ -888,7 +888,7 @@ class ShardingTest(jtu.JaxTestCase):
     op2.type = xc.OpSharding.Type.OTHER
     op2.tile_assignment_devices = [0, 1]
     op2.tile_assignment_dimensions = [2, 1]
-    s8 = jax.sharding.OpShardingSharding(list(mesh2.devices.flat), op2)
+    s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2)
 
     self.assertTrue(s1.is_equivalent_to(s6, 2))
     self.assertTrue(s5.is_equivalent_to(s8, 2))

@@ -901,7 +901,7 @@ class ShardingTest(jtu.JaxTestCase):
     op3.tile_assignment_devices = [0, 1]
     op3.tile_assignment_dimensions = [1, 1, 2]
     op3.replicate_on_last_tile_dim = True
-    s10 = jax.sharding.OpShardingSharding(list(mesh2.devices.flat), op3)
+    s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3)
 
     self.assertTrue(s9.is_equivalent_to(s10, 2))
 

@@ -185,7 +185,7 @@ class PickleTest(jtu.JaxTestCase):
   def test_pickle_op_sharding_sharding(self):
     op_sharding = xla.xc.OpSharding()
     op_sharding.type = xla.xc.OpSharding.Type.REPLICATED
-    s = jax.sharding.OpShardingSharding(jax.devices(), op_sharding)
+    s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
     self.assertEqual(s, pickle.loads(pickle.dumps(s)))
 
   @unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@@ -44,7 +44,7 @@ from jax.experimental import global_device_array
 from jax.experimental import multihost_utils
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src import array
-from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
+from jax._src.sharding import NamedSharding, Sharding, GSPMDSharding
 import jax._src.pjit as pjit_lib
 from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
 from jax._src.interpreters import pxla

@@ -1860,7 +1860,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     def _checks(out, input_data):
       self.assertIsInstance(out, array.ArrayImpl)
-      self.assertIsInstance(out.sharding, OpShardingSharding)
+      self.assertIsInstance(out.sharding, GSPMDSharding)
       self.assertEqual(out.shape, (8, 2))
       self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
       for s in out.addressable_shards:

@@ -2416,20 +2416,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     f = pjit(lambda x: x)
     out1 = f(arr)
-    self.assertIsInstance(out1.sharding, OpShardingSharding)
+    self.assertIsInstance(out1.sharding, GSPMDSharding)
     out1.sharding.devices_indices_map(shape)
-    cache_info1 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
 
     out2 = f(out1)
-    self.assertIsInstance(out2.sharding, OpShardingSharding)
+    self.assertIsInstance(out2.sharding, GSPMDSharding)
     out2.sharding.devices_indices_map(shape)
-    cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
 
     out3 = f(out2)
-    self.assertIsInstance(out3.sharding, OpShardingSharding)
+    self.assertIsInstance(out3.sharding, GSPMDSharding)
     out3.sharding.devices_indices_map(shape)
-    cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
 
   @jax_array(True)

@@ -3871,21 +3871,21 @@ class UtilTest(jtu.JaxTestCase):
     shape = (8, 4)
     devices = jax.devices()
 
-    ops = OpShardingSharding(devices, op1)
+    ops = GSPMDSharding(devices, op1)
     ops.devices_indices_map(shape)
-    cache_info1 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
 
     ops.devices_indices_map(shape)
-    cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
 
-    ops = OpShardingSharding(devices, op2)
+    ops = GSPMDSharding(devices, op2)
     ops.devices_indices_map(shape)
-    cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
 
     ops.devices_indices_map(shape)
-    cache_info4 = OpShardingSharding.devices_indices_map.cache_info()
+    cache_info4 = GSPMDSharding.devices_indices_map.cache_info()
     self.assertEqual(cache_info4.hits, cache_info3.hits + 1)