Add sharding mismatch to explain_tracing_cache_miss

PiperOrigin-RevId: 730645598
This commit is contained in:
Yash Katariya 2025-02-24 16:49:15 -08:00 committed by jax authors
parent d286733399
commit 6f8bab3c92
3 changed files with 35 additions and 2 deletions

View File

@ -1880,13 +1880,15 @@ class ShapedArray(UnshapedArray):
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding)
def str_short(self, short_dtypes=False):
def str_short(self, short_dtypes=False, mesh_axis_types=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
return f'{dt_str}[{shapestr}]'
mesh_axes = (f'({self.sharding.mesh.axis_types})'
if mesh_axis_types else '')
return f'{dt_str}[{shapestr}]{mesh_axes}'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'

View File

@ -1262,6 +1262,9 @@ def explain_tracing_cache_miss(
s1 += f'{{weak_type={ty1.weak_type}}}'
s2 += f'{{weak_type={ty2.weak_type}}}'
add_weak_type_hint = True
elif ty1.sharding != ty2.sharding:
s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True)
s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True)
else:
s1, s2 = str(ty1), str(ty2)
p(f" * at {name}, seen {s1}, but now given {s2}")

View File

@ -43,6 +43,7 @@ from jax._src import prng
from jax.sharding import PartitionSpec as P, Mesh
from jax.experimental import multihost_utils
from jax.experimental.shard_map import shard_map
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax.experimental.custom_partitioning import (
custom_partitioning, SdyShardingRule, BATCHING)
from jax._src import array
@ -3297,6 +3298,33 @@ class ArrayPjitTest(jtu.JaxTestCase):
pjit(_pmapped_fun)(inputs) # doesn't crash
jax.jit(_pmapped_fun)(inputs) # doesn't crash
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_sharding_mismatch(self):
mesh = jtu.create_mesh((2,), ('x',))
s = NamedSharding(mesh, P('x'))
@jax.jit
def f(x, y):
return x * y
np_inp = np.arange(8, dtype=np.float32)
x = np_inp
y = jax.device_put(np_inp, s)
f(x, y)
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
# sharding change
with config.explain_cache_misses(True):
x_ = jax.device_put(np_inp, s)
with self.assertLogs(level='WARNING') as cm:
f(x_, y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen input type signature', msg)
self.assertIn('closest seen input type signature has 1 mismatches', msg)
self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg)
def test_pjit_function_cache_cpp(self):
def f(x):
return x * 2