mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add weak_type to ShapeDtypeStruct because jax.Array also has it and SDS is a duck of jax.Array
This fixes a tracing cache miss issue when you eval shape with a weak_type input and get a strong type output back and pass that back in leading to a cache miss. Fixes: https://github.com/google/jax/issues/23302 PiperOrigin-RevId: 668949430
This commit is contained in:
parent
1dff3a2c71
commit
dd6f0e2e2e
@ -2693,10 +2693,11 @@ class ShapeDtypeStruct:
|
||||
dtype: a dtype-like object
|
||||
sharding: (optional) a :class:`jax.Sharding` object
|
||||
"""
|
||||
__slots__ = ["shape", "dtype", "sharding", "_dll"]
|
||||
__slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"]
|
||||
named_shape = {} # type: ignore
|
||||
|
||||
def __init__(self, shape, dtype, named_shape=None, sharding=None):
|
||||
def __init__(self, shape, dtype, named_shape=None, sharding=None,
|
||||
weak_type=False):
|
||||
del named_shape # ignored, vestigial
|
||||
self.shape = tuple(shape)
|
||||
if dtype is None:
|
||||
@ -2714,6 +2715,7 @@ class ShapeDtypeStruct:
|
||||
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
|
||||
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
|
||||
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
|
||||
self.weak_type = weak_type
|
||||
|
||||
size = property(lambda self: math.prod(self.shape))
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
@ -2731,8 +2733,9 @@ class ShapeDtypeStruct:
|
||||
def __repr__(self):
|
||||
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
|
||||
l = f", layout={self.layout}" if self._dll is not None else ""
|
||||
wt = f", weak_type={self.weak_type}" if self.weak_type else ""
|
||||
return (f"{type(self).__name__}(shape={self.shape}, "
|
||||
f"dtype={self.dtype.name}{sh}{l})")
|
||||
f"dtype={self.dtype.name}{sh}{l}{wt})")
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
@ -2740,17 +2743,19 @@ class ShapeDtypeStruct:
|
||||
if not isinstance(other, ShapeDtypeStruct):
|
||||
return False
|
||||
else:
|
||||
return ((other.shape, other.dtype, other.sharding, other.layout) ==
|
||||
(self.shape, self.dtype, self.sharding, self.layout))
|
||||
return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) ==
|
||||
(other.shape, other.dtype, other.sharding, other.layout, other.weak_type))
|
||||
|
||||
def __hash__(self):
|
||||
# TODO(frostig): avoid the conversion from dict by addressing
|
||||
# https://github.com/google/jax/issues/8182
|
||||
return hash((self.shape, self.dtype, self.sharding, self.layout))
|
||||
return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type))
|
||||
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = (
|
||||
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
|
||||
weak_type=False))
|
||||
def _sds_aval_mapping(x):
|
||||
return ShapedArray(
|
||||
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
|
||||
weak_type=x.weak_type)
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
|
||||
|
||||
|
||||
@api_boundary
|
||||
|
@ -524,7 +524,8 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
|
||||
p, _ = _infer_params(fun, jit_info, args, kwargs)
|
||||
out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
|
||||
# TODO(yashkatariya): Add `Layout` to SDS.
|
||||
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s)
|
||||
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
|
||||
weak_type=x.weak_type)
|
||||
for x, s in zip(p.params['jaxpr'].out_avals, out_s)]
|
||||
return tree_unflatten(p.out_tree, out)
|
||||
|
||||
|
@ -4099,6 +4099,18 @@ class APITest(jtu.JaxTestCase):
|
||||
a2 = jnp.array(((x, x), [x, x]))
|
||||
self.assertAllClose(np.array(((1, 1), (1, 1))), a2)
|
||||
|
||||
def test_eval_shape_weak_type(self):
|
||||
# https://github.com/google/jax/issues/23302
|
||||
arr = jax.numpy.array(1)
|
||||
|
||||
with jtu.count_jit_tracing_cache_miss() as count:
|
||||
jax.eval_shape(jax.numpy.array, 1)
|
||||
out = jax.eval_shape(jax.numpy.array, 1)
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertTrue(out.weak_type)
|
||||
self.assertEqual(out.weak_type, arr.weak_type)
|
||||
|
||||
def test_dunder_jax_array_bug(self):
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
class A:
|
||||
|
Loading…
x
Reference in New Issue
Block a user