From 20658fabb3a2c01ddfec648a6df91bfaa7c27050 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 17 Mar 2025 13:16:52 -0700
Subject: [PATCH] Replace cached function get_replicated_hlo_sharding() with a
 constant.

Small cleanup, no functional changes intended.

PiperOrigin-RevId: 737727727
---
 jax/_src/debugging.py      |  2 +-
 jax/_src/sharding_impls.py | 12 +++++-------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py
index 54ac2d5fd..b61b28e12 100644
--- a/jax/_src/debugging.py
+++ b/jax/_src/debugging.py
@@ -446,7 +446,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
   if len(devices) == 1:
     # If we only have one device in our computation, we can construct a
     # replicated HloSharding and call it right now.
-    _hlo_sharding_callback(sharding_impls.get_replicated_hlo_sharding())
+    _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
     return []
 
   key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 019411c77..60e8c54a4 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -114,9 +114,7 @@ def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
   return sdy_sharding
 
 
-@util.cache(max_size=128, trace_context_in_key=False)
-def get_replicated_hlo_sharding():
-  return xc.HloSharding.replicate()
+replicated_hlo_sharding = xc.HloSharding.replicate()
 
 
 @use_cpp_class(xc.SingleDeviceSharding)
@@ -183,7 +181,7 @@ class SingleDeviceSharding(jsharding.Sharding):
     return (self._device,)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
 
   def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     sdy_dim_sharding = [SdyDimSharding(axes=[], is_closed=True)
@@ -401,7 +399,7 @@ def _op_sharding_to_pos_sharding(
 def _positional_sharding_to_xla_hlo_sharding(
     self, num_dimensions: int) -> xc.HloSharding:
   if self.shape == (1,) * self.ndim:
-    return get_replicated_hlo_sharding()
+    return replicated_hlo_sharding
 
   pbuf = xc.OpSharding()
   shape = self.shape[self.ndim - num_dimensions:]  # 'rank promotion' of val
@@ -603,7 +601,7 @@ class GSPMDSharding(jsharding.Sharding):
   @functools.cached_property
   def _hlo_sharding_hash(self):
     if self.is_fully_replicated:
-      return hash(get_replicated_hlo_sharding())
+      return hash(replicated_hlo_sharding)
     return hash(self._hlo_sharding)
 
   def __eq__(self, other):
@@ -669,7 +667,7 @@ class GSPMDSharding(jsharding.Sharding):
   @classmethod
   def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    return cls(tuple(device_assignment), get_replicated_hlo_sharding(),
+    return cls(tuple(device_assignment), replicated_hlo_sharding,
                memory_kind=memory_kind)
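
Note (illustration only, not part of the applied diff): the refactor is safe
because xc.HloSharding.replicate() takes no arguments and returns an
immutable, hashable value, and the original decorator's
trace_context_in_key=False already declared that the cached result does not
depend on the trace context. A module-level constant built once at import
time is therefore equivalent and avoids a cache lookup on every call. Below
is a minimal sketch of the pattern, using functools.lru_cache and a
hypothetical ReplicatedSharding stand-in rather than the real xc.HloSharding:

  import functools

  class ReplicatedSharding:
    """Stand-in for the immutable, hashable value that
    xc.HloSharding.replicate() returns (hypothetical, for illustration)."""

    def __eq__(self, other):
      return isinstance(other, ReplicatedSharding)

    def __hash__(self):
      return hash(ReplicatedSharding)

  # Before: a cached zero-argument factory. Every call pays a cache lookup
  # just to hand back an equivalent value.
  @functools.lru_cache(maxsize=128)
  def get_replicated_sharding():
    return ReplicatedSharding()

  # After: build the value once at import time and reference it directly.
  replicated_sharding = ReplicatedSharding()

  # The two spellings are interchangeable wherever equality or hashing is
  # what matters, as in GSPMDSharding._hlo_sharding_hash above.
  assert get_replicated_sharding() == replicated_sharding
  assert hash(get_replicated_sharding()) == hash(replicated_sharding)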