Rename jax.sharding.OpShardingSharding to jax.sharding.GSPMDSharding. jax.sharding.OpShardingSharding will be removed in 3 months from Feb 17, 2023.

PiperOrigin-RevId: 510556189
Yash Katariya authored 2023-02-17 17:10:27 -08:00; committed by jax authors
parent f7734fd6a4
commit 0ffdeb3de2
17 changed files with 86 additions and 73 deletions
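
For context, a minimal migration sketch (not part of the diff; it assumes a jax build that includes this change and at least one available device). The construction mirrors the test updates further below:

import jax
from jax._src.lib import xla_client as xc
from jax.sharding import GSPMDSharding

# Build a trivially replicated OpSharding proto, as the updated tests do.
op = xc.OpSharding()
op.type = xc.OpSharding.Type.REPLICATED

# New spelling.
s_new = GSPMDSharding(jax.devices(), op)

# Old spelling keeps working during the deprecation window
# (removal planned 3 months from Feb 17, 2023).
s_old = jax.sharding.OpShardingSharding(jax.devices(), op)
assert s_old == s_new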

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.5
* Deprecations
* `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
`jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
## jaxlib 0.4.5
## jax 0.4.4 (Feb 16, 2023)

View File

@ -655,7 +655,7 @@ with mesh:
c:bf16[32,512,512] d:bf16[512,512] = pjit[
donated_invars=(False, False)
in_positional_semantics=(<_PositionalSemantics.GLOBAL: 1>, <_PositionalSemantics.GLOBAL: 1>)
in_shardings=(OpShardingSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), OpShardingSharding({replicated}))
in_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated}))
jaxpr={ lambda ; e:bf16[32,512,512] f:bf16[512,512]. let
g:bf16[8,4,512,512] = reshape[
dimensions=None
@ -734,7 +734,7 @@ with mesh:
in (bn, bf) }
name=<unnamed function>
out_positional_semantics=_PositionalSemantics.GLOBAL
out_shardings=(OpShardingSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), OpShardingSharding({replicated}))
out_shardings=(GSPMDSharding({devices=[8,1,1]0,1,2,3,4,5,6,7}), GSPMDSharding({replicated}))
resource_env=ResourceEnv(Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('x',)), ())
] a b
in (c, d) }

View File

@ -43,7 +43,7 @@ from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.sharding import GSPMDSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache)
@ -860,7 +860,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
if jax.config.jax_array:
sharding = pjit._UNSPECIFIED
else:
sharding = OpShardingSharding.get_replicated(
sharding = GSPMDSharding.get_replicated(
list(resource_env.physical_mesh.devices.flat))
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)

View File

@ -42,7 +42,7 @@ from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding, OpShardingSharding, NamedSharding
from jax._src.sharding import Sharding, GSPMDSharding, NamedSharding
# pytype: disable=import-error
try:
@ -310,7 +310,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
def _op_sharding_callback(op_sharding: xc.OpSharding):
if mesh.empty:
return callback(OpShardingSharding(
return callback(GSPMDSharding(
devices, op_sharding))
pspec = pjit.parse_flatten_op_sharding(
op_sharding, mesh)[0].get_partition_spec()

View File

@ -61,7 +61,7 @@ from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
OpShardingSharding, NamedSharding, PartitionSpec,
GSPMDSharding, NamedSharding, PartitionSpec,
Sharding)
from jax._src.util import flatten, unflatten
@ -322,7 +322,7 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
assert len(da) == 1
# in_shardings will be marked as replicated regardless of whatever the input
# had. Given that only a single device is allowed above, this is correct.
in_shardings = [OpShardingSharding.get_replicated(da)] * num_ins
in_shardings = [GSPMDSharding.get_replicated(da)] * num_ins
return da, in_shardings

View File

@ -2896,7 +2896,7 @@ def lower_sharding_computation(
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or # type: ignore
any(not _is_unspecified(o) for o in out_shardings)) # type: ignore
in_shardings = tuple(sharding_internal.OpShardingSharding.get_replicated(device_assignment)
in_shardings = tuple(sharding_internal.GSPMDSharding.get_replicated(device_assignment)
if _is_unspecified(i) else i for i in in_shardings)
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@ -3372,9 +3372,9 @@ def get_op_sharding_shardings_from_executable(
in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
in_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, i)
in_shardings_xla = [sharding_internal.GSPMDSharding(device_assignment, i)
for i in in_op_shardings]
out_shardings_xla = [sharding_internal.OpShardingSharding(device_assignment, o)
out_shardings_xla = [sharding_internal.GSPMDSharding(device_assignment, o)
for o in out_op_shardings]
# This condition happens when all the elements in the output tuple have the
# same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
@ -3720,7 +3720,7 @@ def _out_shardings_for_trivial(
# a replicated sharding
from jax._src import array
rep = sharding_internal.OpShardingSharding(
rep = sharding_internal.GSPMDSharding(
device_assignment, sharding_internal._get_replicated_op_sharding())
shardings: Dict[core.Var, sharding_internal.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):

View File

@ -1860,7 +1860,7 @@ def _check_no_loop_collectives(jaxpr, loop_axis_resources):
def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
from jax._src.pjit import (
sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
OpShardingSharding)
GSPMDSharding)
rec = lambda jaxpr: _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name)
if isinstance(jaxpr, core.ClosedJaxpr):
@ -1878,7 +1878,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
mps = NamedSharding._from_parsed_pspec(
resource_env.physical_mesh, ParsedPartitionSpec((), ()))
unconstrained_dims = get_unconstrained_dims(mps)
op_sharding_sharding = OpShardingSharding.get_replicated(
op_sharding_sharding = GSPMDSharding.get_replicated(
mps._device_assignment)
new_eqns.append(core.JaxprEqn(
[tmpvar], [outvar], sharding_constraint_p,

View File

@ -37,7 +37,7 @@ from jax.tree_util import (
treedef_tuple)
from jax._src.sharding import (
NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
NamedSharding, Sharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
from jax._src import array
from jax._src import dispatch
@ -90,8 +90,8 @@ def _is_unspecified_or_from_gda_or_auto(x):
return _is_from_gda(x) or is_auto(x) or _is_unspecified(x)
PjitSharding = Union[OpShardingSharding, _UnspecifiedValue, _AUTOAxisResource]
PjitShardingMinusUnspecified = Union[OpShardingSharding, _AUTOAxisResource]
PjitSharding = Union[GSPMDSharding, _UnspecifiedValue, _AUTOAxisResource]
PjitShardingMinusUnspecified = Union[GSPMDSharding, _AUTOAxisResource]
MeshSharding = Union[NamedSharding, _UnspecifiedValue, _AUTOAxisResource]
MeshShardingMinusUnspecified = Union[NamedSharding, _AUTOAxisResource]
@ -535,7 +535,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
in_positional_semantics = (
pxla._PositionalSemantics.GLOBAL,) * len(consts) + in_positional_semantics
# in_shardings and out_shardings here are all OpShardingSharding.
# in_shardings and out_shardings here are all GSPMDSharding.
params = dict(
jaxpr=jaxpr,
in_shardings=canonicalized_in_shardings_flat,
@ -1353,7 +1353,7 @@ class SameDeviceAssignmentTuple:
device_assignment: Optional[XLADeviceAssignment]
def __hash__(self):
shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s # type: ignore
shardings_hash = tuple(s._op_sharding_hash if isinstance(s, GSPMDSharding) else s # type: ignore
for s in self.shardings)
if self.device_assignment is None:
return hash(shardings_hash)
@ -1364,7 +1364,7 @@ class SameDeviceAssignmentTuple:
if not isinstance(other, SameDeviceAssignmentTuple):
return False
return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding) # pytype: disable=attribute-error
if isinstance(s, OpShardingSharding) and isinstance(o, OpShardingSharding)
if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
else s == o
for s, o in safe_zip(self.shardings, other.shardings)) and
self.device_assignment == other.device_assignment)
@ -1420,14 +1420,14 @@ def _pjit_lower_cached(
Tuple[MeshShardingMinusUnspecified, ...], tuple(
NamedSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0]) # type: ignore
if isinstance(i, OpShardingSharding) else i
if isinstance(i, GSPMDSharding) else i
for i in in_shardings
))
out_shardings: Tuple[MeshSharding, ...] = cast( # type: ignore[no-redef]
Tuple[MeshSharding, ...], tuple(
NamedSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0]) # type: ignore
if isinstance(o, OpShardingSharding) else o
if isinstance(o, GSPMDSharding) else o
for o in out_shardings
))
@ -1560,7 +1560,7 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
def _pjit_batcher_for_sharding(
s: Union[OpShardingSharding, _UnspecifiedValue],
s: Union[GSPMDSharding, _UnspecifiedValue],
dim: int, val: Tuple[str, ...], mesh, ndim: int):
if _is_unspecified(s):
return s
@ -1569,14 +1569,14 @@ def _pjit_batcher_for_sharding(
tad = list(new_op.tile_assignment_dimensions)
tad.insert(dim, 1)
new_op.tile_assignment_dimensions = tad
return OpShardingSharding(s._device_assignment, new_op) # type: ignore
return GSPMDSharding(s._device_assignment, new_op) # type: ignore
else:
assert isinstance(s, OpShardingSharding)
assert isinstance(s, GSPMDSharding)
assert mesh is not None and not mesh.empty
parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0] # type: ignore
parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
return OpShardingSharding(mps._device_assignment, mps._to_xla_op_sharding(ndim))
return GSPMDSharding(mps._device_assignment, mps._to_xla_op_sharding(ndim))
def _pjit_jvp(primals_in, tangents_in,
@ -1642,7 +1642,7 @@ def _pjit_partial_eval(trace, *in_tracers,
residual_shardings = (_UNSPECIFIED,) * num_residuals
else:
da = list(resource_env.physical_mesh.devices.flat)
residual_shardings = (OpShardingSharding.get_replicated(da),) * num_residuals
residual_shardings = (GSPMDSharding.get_replicated(da),) * num_residuals
# Compute the known outputs
known_params = dict(
jaxpr=known_jaxpr,
@ -1683,7 +1683,7 @@ def _pjit_partial_eval(trace, *in_tracers,
residual_op_shardings = ()
assert len(residual_shardings) == len(residual_op_shardings), (
len(residual_shardings), len(residual_op_shardings))
residual_shardings = tuple(OpShardingSharding(da, op) for op in residual_op_shardings)
residual_shardings = tuple(GSPMDSharding(da, op) for op in residual_op_shardings)
known_params['out_shardings'] = (
keep_where(out_shardings, known_outs) + residual_shardings)
@ -2080,13 +2080,13 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
aval, = ctx.avals_in
axis_ctx = ctx.module_context.axis_context
# axis_ctx and manual_axes are *only used with xmap* and xmap only works with
# NamedSharding. So convert the OpShardingSharding to NamedSharding
# NamedSharding. So convert the GSPMDSharding to NamedSharding
# and then convert it back with the added special axes.
if isinstance(axis_ctx, mlir.SPMDAxisContext):
mesh = resource_env.physical_mesh
parsed_pspec = parse_flatten_op_sharding(sharding._op_sharding, mesh)[0]
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
sharding = OpShardingSharding(
sharding = GSPMDSharding(
mps._device_assignment, mps._to_xla_op_sharding(aval.ndim, axis_ctx=axis_ctx))
return [
mlir.wrap_with_sharding_op(
@ -2152,10 +2152,10 @@ def get_array_mapping(
if axes is not None for axis in axes)
def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> OpShardingSharding:
if isinstance(s, OpShardingSharding):
def to_op_sharding_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
if isinstance(s, GSPMDSharding):
return s
op_sharding_sharding = OpShardingSharding(
op_sharding_sharding = GSPMDSharding(
s._device_assignment, s._to_xla_op_sharding(ndim))
op_sharding_sharding._original_sharding = s
return op_sharding_sharding
@ -2232,7 +2232,7 @@ def _maybe_replace_from_gda_with_pspec(
@lru_cache()
def _gda_check_and_get_sharding(
gda_sharding: NamedSharding, in_sharding: OpShardingSharding, ndim: int):
gda_sharding: NamedSharding, in_sharding: GSPMDSharding, ndim: int):
if not _is_from_gda(in_sharding) and not pxla.are_op_shardings_equal(
gda_sharding._to_xla_op_sharding(ndim),
in_sharding._to_xla_op_sharding(ndim)):

View File

@ -43,7 +43,7 @@ from jax._src.lib import gpu_prng
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.sharding import (
NamedSharding, PmapSharding, OpShardingSharding)
NamedSharding, PmapSharding, GSPMDSharding)
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
map, unsafe_map = safe_map, map
@ -375,7 +375,7 @@ class KeyTyRules:
if is_out_sharding_from_xla:
phys_sharding = out_sharding
else:
phys_sharding = OpShardingSharding(
phys_sharding = GSPMDSharding(
out_sharding._device_assignment,
KeyTyRules.physical_op_sharding(aval, out_sharding))

View File

@ -24,6 +24,7 @@ import jax
from jax._src import core
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import mlir
from jax._src.interpreters import pxla
@ -75,10 +76,10 @@ class Sharding(metaclass=abc.ABCMeta):
"""Returns True if two shardings put the same logical array
(sharded/unsharded) on the same device(s).
For example, every XLACompatibleSharding lowers to OpShardingSharding which
For example, every XLACompatibleSharding lowers to GSPMDSharding which
is a general representation. So `jax.sharding.NamedSharding` is equivalent
to `jax.sharding.PositionalSharding` if both of them lower to the same
OpShardingSharding.
GSPMDSharding.
"""
raise NotImplementedError('Subclasses should implement this method.')
@ -136,7 +137,7 @@ class XLACompatibleSharding(Sharding, metaclass=abc.ABCMeta):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
op_sharding = self._to_xla_op_sharding(len(global_shape))
op_sharding_sharding = OpShardingSharding(self._device_assignment,
op_sharding_sharding = GSPMDSharding(self._device_assignment,
op_sharding)
return op_sharding_sharding.devices_indices_map(global_shape)
@ -645,8 +646,8 @@ class DeviceIdSet:
self._ids == other._ids)
@use_cpp_class(xc.OpShardingSharding)
class OpShardingSharding(XLACompatibleSharding):
@use_cpp_class(xc.GSPMDSharding if xla_extension_version >= 129 else xc.OpShardingSharding) # type: ignore
class GSPMDSharding(XLACompatibleSharding):
@use_cpp_method()
def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
@ -661,7 +662,7 @@ class OpShardingSharding(XLACompatibleSharding):
return hash(xc.HloSharding.from_proto(self._op_sharding))
def __eq__(self, other):
if not isinstance(other, OpShardingSharding):
if not isinstance(other, GSPMDSharding):
return False
if id(self) == id(other):
return True
@ -674,7 +675,7 @@ class OpShardingSharding(XLACompatibleSharding):
return self._hash
def __repr__(self):
return f'OpShardingSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
return f'GSPMDSharding({repr(xc.HloSharding.from_proto(self._op_sharding))})'
def is_compatible_aval(self, aval_shape: Shape):
num_ways_dim_sharded, _ = pxla.get_num_ways_dim_sharded(self._op_sharding)
@ -705,3 +706,8 @@ class OpShardingSharding(XLACompatibleSharding):
def get_replicated(cls, device_assignment):
proto = _get_replicated_op_sharding()
return cls(device_assignment, proto)
# TODO(yashkatariya): Remove OpShardingSharding 3 months from Feb 17, 2023
# per the deprecation policy.
OpShardingSharding = GSPMDSharding
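
As an aside (not part of the diff), a small hedged sketch of the renamed class as defined above: get_replicated, the cached devices_indices_map, and the equivalence behaviour described in the updated docstring. It assumes jax.sharding.Mesh is available at this revision; the device count is arbitrary and the printed repr is illustrative.

import jax
import numpy as np
from jax.sharding import GSPMDSharding, Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()

# Replicated sharding over every local device, via the classmethod kept above.
replicated = GSPMDSharding.get_replicated(devices)
print(replicated)  # e.g. GSPMDSharding({replicated})

# devices_indices_map is inherited from XLACompatibleSharding and lru_cached,
# which the pjit cache-hit tests further below rely on.
assert len(replicated.devices_indices_map((8, 2))) == len(devices)

# Per the updated is_equivalent_to docstring, a fully replicated NamedSharding
# lowers to the same GSPMDSharding, so the two should compare as equivalent.
mesh = Mesh(np.asarray(devices), ('x',))
named_replicated = NamedSharding(mesh, P())
assert named_replicated.is_equivalent_to(replicated, 2)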

View File

@ -385,8 +385,8 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
def to_mesh_pspec_sharding(op_sharding: xc.OpSharding):
if mesh.empty:
from jax._src.sharding import OpShardingSharding
return OpShardingSharding(devices, op_sharding)
from jax._src.sharding import GSPMDSharding
return GSPMDSharding(devices, op_sharding)
pspec = pjit.parse_flatten_op_sharding(op_sharding,
mesh)[0].get_partition_spec()
return jax.sharding.NamedSharding(mesh, pspec)

View File

@ -22,7 +22,7 @@ from jax._src import util
from jax._src import config as jax_config
from jax.config import config
from jax._src import array
from jax._src.sharding import NamedSharding, OpShardingSharding
from jax._src.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.gda_serialization import serialization
import numpy as np
@ -133,7 +133,7 @@ class CheckpointTest(jtu.JaxTestCase):
for l in m1.addressable_shards:
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
new_ds = OpShardingSharding.get_replicated(list(global_mesh.devices.flat))
new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
for l in m2.addressable_shards:
self.assertArraysEqual(l.data, global_input_data1.astype('float32'))

View File

@ -96,7 +96,7 @@ def _identity_fn(x):
def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding.OpShardingSharding(inp.sharding._device_assignment,
reps = sharding.GSPMDSharding(inp.sharding._device_assignment,
sharding._get_replicated_op_sharding())
out = pjit(_identity_fn, out_axis_resources=reps)(inp)
else:

View File

@ -18,13 +18,16 @@
from jax._src.sharding import (
Sharding as Sharding,
XLACompatibleSharding as XLACompatibleSharding,
# TODO(yashkatariya): Deprecate MeshPspecSharding in 3 months.
# TODO(yashkatariya): Remove MeshPspecSharding in 3 months.
MeshPspecSharding as MeshPspecSharding,
# New name of MeshPspecSharding to match PositionalSharding below.
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
GSPMDSharding as GSPMDSharding,
# TODO(yashkatariya): Remove OpShardingSharding in 3 months from
# Feb 17, 2023.
OpShardingSharding as OpShardingSharding,
PositionalSharding as PositionalSharding,
)
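
A quick check of the re-export wiring above (assuming a build with this commit): during the deprecation window the old name is a plain alias for the new class, so existing isinstance checks keep passing.

import jax
from jax.sharding import GSPMDSharding, OpShardingSharding

# The deprecated export points at the very same class object.
assert OpShardingSharding is GSPMDSharding
assert jax.sharding.OpShardingSharding is jax.sharding.GSPMDSharding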

View File

@ -694,7 +694,7 @@ class ShardingTest(jtu.JaxTestCase):
shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.NamedSharding(mesh, pspec)
ops = sharding.OpShardingSharding(
ops = sharding.GSPMDSharding(
list(mesh.devices.flat), mps._to_xla_op_sharding(len(shape)))
self.assertDictEqual(
ops.devices_indices_map(shape), mps.devices_indices_map(shape))
@ -771,16 +771,16 @@ class ShardingTest(jtu.JaxTestCase):
op.tile_assignment_dimensions = [4, 1, 2]
op.tile_assignment_devices = [0, 1, 2, 3, 4, 5, 6, 7]
op.replicate_on_last_tile_dim = True
s = sharding.OpShardingSharding(jax.devices(), op)
s = sharding.GSPMDSharding(jax.devices(), op)
self.assertEqual(
repr(s),
'OpShardingSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
'GSPMDSharding({devices=[4,1,2]0,1,2,3,4,5,6,7 '
'last_tile_dim_replicate})')
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.REPLICATED
s2 = sharding.OpShardingSharding(jax.devices(), op2)
self.assertEqual(repr(s2), 'OpShardingSharding({replicated})')
s2 = sharding.GSPMDSharding(jax.devices(), op2)
self.assertEqual(repr(s2), 'GSPMDSharding({replicated})')
@parameterized.named_parameters(
("mesh_x_y", P("x", "y"), (4, 2), (), False),
@ -876,9 +876,9 @@ class ShardingTest(jtu.JaxTestCase):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.REPLICATED
s6 = jax.sharding.OpShardingSharding([jax.devices()[0]], op1)
s6 = jax.sharding.GSPMDSharding([jax.devices()[0]], op1)
s7 = jax.sharding.OpShardingSharding(jax.devices(), op1)
s7 = jax.sharding.GSPMDSharding(jax.devices(), op1)
# The OpSharding is replicated but the Shardings themselves are on different
# devices.
@ -888,7 +888,7 @@ class ShardingTest(jtu.JaxTestCase):
op2.type = xc.OpSharding.Type.OTHER
op2.tile_assignment_devices = [0, 1]
op2.tile_assignment_dimensions = [2, 1]
s8 = jax.sharding.OpShardingSharding(list(mesh2.devices.flat), op2)
s8 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op2)
self.assertTrue(s1.is_equivalent_to(s6, 2))
self.assertTrue(s5.is_equivalent_to(s8, 2))
@ -901,7 +901,7 @@ class ShardingTest(jtu.JaxTestCase):
op3.tile_assignment_devices = [0, 1]
op3.tile_assignment_dimensions = [1, 1, 2]
op3.replicate_on_last_tile_dim = True
s10 = jax.sharding.OpShardingSharding(list(mesh2.devices.flat), op3)
s10 = jax.sharding.GSPMDSharding(list(mesh2.devices.flat), op3)
self.assertTrue(s9.is_equivalent_to(s10, 2))

View File

@ -185,7 +185,7 @@ class PickleTest(jtu.JaxTestCase):
def test_pickle_op_sharding_sharding(self):
op_sharding = xla.xc.OpSharding()
op_sharding.type = xla.xc.OpSharding.Type.REPLICATED
s = jax.sharding.OpShardingSharding(jax.devices(), op_sharding)
s = jax.sharding.GSPMDSharding(jax.devices(), op_sharding)
self.assertEqual(s, pickle.loads(pickle.dumps(s)))
@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")

View File

@ -44,7 +44,7 @@ from jax.experimental import global_device_array
from jax.experimental import multihost_utils
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import NamedSharding, Sharding, OpShardingSharding
from jax._src.sharding import NamedSharding, Sharding, GSPMDSharding
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, FROM_GDA, AUTO)
from jax._src.interpreters import pxla
@ -1860,7 +1860,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def _checks(out, input_data):
self.assertIsInstance(out, array.ArrayImpl)
self.assertIsInstance(out.sharding, OpShardingSharding)
self.assertIsInstance(out.sharding, GSPMDSharding)
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
for s in out.addressable_shards:
@ -2416,20 +2416,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x)
out1 = f(arr)
self.assertIsInstance(out1.sharding, OpShardingSharding)
self.assertIsInstance(out1.sharding, GSPMDSharding)
out1.sharding.devices_indices_map(shape)
cache_info1 = OpShardingSharding.devices_indices_map.cache_info()
cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
out2 = f(out1)
self.assertIsInstance(out2.sharding, OpShardingSharding)
self.assertIsInstance(out2.sharding, GSPMDSharding)
out2.sharding.devices_indices_map(shape)
cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
out3 = f(out2)
self.assertIsInstance(out3.sharding, OpShardingSharding)
self.assertIsInstance(out3.sharding, GSPMDSharding)
out3.sharding.devices_indices_map(shape)
cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
@jax_array(True)
@ -3871,21 +3871,21 @@ class UtilTest(jtu.JaxTestCase):
shape = (8, 4)
devices = jax.devices()
ops = OpShardingSharding(devices, op1)
ops = GSPMDSharding(devices, op1)
ops.devices_indices_map(shape)
cache_info1 = OpShardingSharding.devices_indices_map.cache_info()
cache_info1 = GSPMDSharding.devices_indices_map.cache_info()
ops.devices_indices_map(shape)
cache_info2 = OpShardingSharding.devices_indices_map.cache_info()
cache_info2 = GSPMDSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
ops = OpShardingSharding(devices, op2)
ops = GSPMDSharding(devices, op2)
ops.devices_indices_map(shape)
cache_info3 = OpShardingSharding.devices_indices_map.cache_info()
cache_info3 = GSPMDSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)
ops.devices_indices_map(shape)
cache_info4 = OpShardingSharding.devices_indices_map.cache_info()
cache_info4 = GSPMDSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)