Replace cached function get_replicated_hlo_sharding() with a constant.

Small cleanup, no functional changes intended.
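For context, a minimal sketch of the pattern being removed (hypothetical names;
functools.lru_cache stands in for jax's internal util.cache, and a plain tuple
for xc.HloSharding.replicate()):

    import functools

    # Before: a zero-argument cached function. Every call pays decorator and
    # cache-lookup overhead just to return the same object.
    @functools.lru_cache(maxsize=None)
    def get_value():
      return ("replicated",)

    # After: a module-level constant, built once at import time.
    value = ("replicated",)

    assert get_value() is get_value()  # the cache already made it a singleton

A zero-argument function's cache can only ever hold one entry, so a constant is
equivalent and strictly cheaper.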

PiperOrigin-RevId: 737727727
Peter Hawkins 2025-03-17 13:16:52 -07:00 committed by jax authors
parent ebcae0d30a
commit 20658fabb3
2 changed files with 6 additions and 8 deletions

jax/_src/debugging.py

@@ -446,7 +446,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
   if len(devices) == 1:
     # If we only have one device in our computation, we can construct a
     # replicated HloSharding and call it right now.
-    _hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding())
+    _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
     return []
   key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
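Aside (a sketch, not part of the change): a replicated HloSharding places the
full array on every participating device, which is why it is trivially correct
when there is exactly one device. Assuming jaxlib's xla_client is importable as
xc, as in this file:

    from jax._src.lib import xla_client as xc

    s = xc.HloSharding.replicate()
    assert s.is_replicated()  # fully replicated across all devices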

jax/_src/sharding_impls.py

@@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
   return sdy_sharding
 
-@util.cache(max_size=128, trace_context_in_key=False)
-def get_replicated_hlo_sharding():
-  return xc.HloSharding.replicate()
+replicated_hlo_sharding = xc.HloSharding.replicate()
 
 @use_cpp_class(xc.SingleDeviceSharding)
@@ -183,7 +181,7 @@ class SingleDeviceSharding(jsharding.Sharding):
     return (self._device,)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
 
   def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
@@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
 def _positional_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   if self.shape == (1,) * self.ndim:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
   pbuf = xc.OpSharding()
   shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
@@ -603,7 +601,7 @@ class GSPMDSharding(jsharding.Sharding):
   @functools.cached_property
   def _hlo_sharding_hash(self):
     if self.is_fully_replicated:
-      return hash(get_replicated_hlo_sharding())
+      return hash(replicated_hlo_sharding)
     return hash(self._hlo_sharding)
 
   def __eq__(self, other):
@@ -669,7 +667,7 @@ class GSPMDSharding(jsharding.Sharding):
   @classmethod
   def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
+    return cls(tuple(device_assignment), replicated_hlo_sharding,
                memory_kind=memory_kind)