mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add sharding mismatch to explain_tracing_cache_miss
PiperOrigin-RevId: 730645598
This commit is contained in:
parent
d286733399
commit
6f8bab3c92
@ -1880,13 +1880,15 @@ class ShapedArray(UnshapedArray):
|
||||
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type, sharding=self.sharding)
|
||||
|
||||
def str_short(self, short_dtypes=False, mesh_axis_types=False):
  """Return a compact ``dtype[shape]`` string describing this abstract value.

  Args:
    short_dtypes: if true, the dtype is rendered with
      ``dtypes.short_dtype_name``; otherwise ``self.dtype.name`` is used.
    mesh_axis_types: if true and ``self.sharding`` is not None, the axis
      types of the sharding's mesh are appended in parentheses after the
      closing bracket. Ignored when ``self.sharding`` is None.

  Returns:
    A string of the form ``dtype[shape]``. When a sharding is present the
    shape part is produced by ``_get_shape_sharding_str`` from the shape and
    the sharding spec; otherwise it is the comma-joined dimensions.
  """
  dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
            self.dtype.name)
  # A dtype printed as 'void' is shown to users as 'float0'.
  dt_str = dt_str.replace('void', 'float0')
  if self.sharding is not None:
    shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
    # The mesh axis types are opt-in so the default rendering stays unchanged.
    mesh_axes = (f'({self.sharding.mesh.axis_types})'
                 if mesh_axis_types else '')
    return f'{dt_str}[{shapestr}]{mesh_axes}'
  else:
    shapestr = ','.join(map(str, self.shape))
    return f'{dt_str}[{shapestr}]'
|
||||
|
@ -1262,6 +1262,9 @@ def explain_tracing_cache_miss(
|
||||
s1 += f'{{weak_type={ty1.weak_type}}}'
|
||||
s2 += f'{{weak_type={ty2.weak_type}}}'
|
||||
add_weak_type_hint = True
|
||||
elif ty1.sharding != ty2.sharding:
|
||||
s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True)
|
||||
s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True)
|
||||
else:
|
||||
s1, s2 = str(ty1), str(ty2)
|
||||
p(f" * at {name}, seen {s1}, but now given {s2}")
|
||||
|
@ -43,6 +43,7 @@ from jax._src import prng
|
||||
from jax.sharding import PartitionSpec as P, Mesh
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax._src.compilation_cache import is_persistent_cache_enabled
|
||||
from jax.experimental.custom_partitioning import (
|
||||
custom_partitioning, SdyShardingRule, BATCHING)
|
||||
from jax._src import array
|
||||
@ -3297,6 +3298,33 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
pjit(_pmapped_fun)(inputs) # doesn't crash
|
||||
jax.jit(_pmapped_fun)(inputs) # doesn't crash
|
||||
|
||||
@jtu.thread_unsafe_test()  # logging is not thread-safe
def test_cache_miss_explanations_sharding_mismatch(self):
  """Checks that a sharding-only mismatch is reported in the cache-miss log.

  The jitted function is first called with an unsharded NumPy input and a
  sharded array. It is then called again with both inputs sharded while
  ``explain_cache_misses`` is enabled, and the emitted warning is expected to
  show the previously seen and the newly given type including mesh axis types.
  """
  mesh = jtu.create_mesh((2,), ('x',))
  s = NamedSharding(mesh, P('x'))

  @jax.jit
  def f(x, y):
    return x * y

  np_inp = np.arange(8, dtype=np.float32)
  x = np_inp
  y = jax.device_put(np_inp, s)
  # First call populates the tracing cache with (unsharded x, sharded y).
  f(x, y)

  # Three log records are expected instead of one when the persistent
  # compilation cache is enabled.
  expected_log_len = 1 if not is_persistent_cache_enabled() else 3

  # sharding change
  with config.explain_cache_misses(True):
    x_ = jax.device_put(np_inp, s)
    with self.assertLogs(level='WARNING') as cm:
      f(x_, y)
    self.assertLen(cm.output, expected_log_len)
    msg = cm.output[0]
    self.assertIn('never seen input type signature', msg)
    self.assertIn('closest seen input type signature has 1 mismatches', msg)
    self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg)
|
||||
|
||||
def test_pjit_function_cache_cpp(self):
|
||||
def f(x):
|
||||
return x * 2
|
||||
|
Loading…
x
Reference in New Issue
Block a user