Add weak_type to ShapeDtypeStruct because jax.Array also has it and ShapeDtypeStruct (SDS) is a duck type of jax.Array

This fixes a tracing cache miss that occurs when you call eval_shape with a weak_type input, get a strongly typed output back, and then pass that output back in.

Fixes: https://github.com/google/jax/issues/23302
PiperOrigin-RevId: 668949430
This commit is contained in:
Yash Katariya 2024-08-29 08:35:00 -07:00 committed by jax authors
parent 1dff3a2c71
commit dd6f0e2e2e
3 changed files with 28 additions and 10 deletions

View File

@ -2693,10 +2693,11 @@ class ShapeDtypeStruct:
dtype: a dtype-like object
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "sharding", "_dll"]
__slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"]
named_shape = {} # type: ignore
def __init__(self, shape, dtype, named_shape=None, sharding=None):
def __init__(self, shape, dtype, named_shape=None, sharding=None,
weak_type=False):
del named_shape # ignored, vestigial
self.shape = tuple(shape)
if dtype is None:
@ -2714,6 +2715,7 @@ class ShapeDtypeStruct:
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
self.weak_type = weak_type
size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
@ -2731,8 +2733,9 @@ class ShapeDtypeStruct:
def __repr__(self):
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
l = f", layout={self.layout}" if self._dll is not None else ""
wt = f", weak_type={self.weak_type}" if self.weak_type else ""
return (f"{type(self).__name__}(shape={self.shape}, "
f"dtype={self.dtype.name}{sh}{l})")
f"dtype={self.dtype.name}{sh}{l}{wt})")
__str__ = __repr__
@ -2740,17 +2743,19 @@ class ShapeDtypeStruct:
if not isinstance(other, ShapeDtypeStruct):
return False
else:
return ((other.shape, other.dtype, other.sharding, other.layout) ==
(self.shape, self.dtype, self.sharding, self.layout))
return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) ==
(other.shape, other.dtype, other.sharding, other.layout, other.weak_type))
def __hash__(self):
# TODO(frostig): avoid the conversion from dict by addressing
# https://github.com/google/jax/issues/8182
return hash((self.shape, self.dtype, self.sharding, self.layout))
return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=False))
def _sds_aval_mapping(x):
return ShapedArray(
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
@api_boundary

View File

@ -524,7 +524,8 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
p, _ = _infer_params(fun, jit_info, args, kwargs)
out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
# TODO(yashkatariya): Add `Layout` to SDS.
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s)
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
weak_type=x.weak_type)
for x, s in zip(p.params['jaxpr'].out_avals, out_s)]
return tree_unflatten(p.out_tree, out)

View File

@ -4099,6 +4099,18 @@ class APITest(jtu.JaxTestCase):
a2 = jnp.array(((x, x), [x, x]))
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
def test_eval_shape_weak_type(self):
# https://github.com/google/jax/issues/23302
arr = jax.numpy.array(1)
with jtu.count_jit_tracing_cache_miss() as count:
jax.eval_shape(jax.numpy.array, 1)
out = jax.eval_shape(jax.numpy.array, 1)
self.assertEqual(count[0], 1)
self.assertTrue(out.weak_type)
self.assertEqual(out.weak_type, arr.weak_type)
def test_dunder_jax_array_bug(self):
@jax.tree_util.register_pytree_node_class
class A: