1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

Merge pull request from mattjj:pvary-errors

PiperOrigin-RevId: 746070501
This commit is contained in:
jax authors 2025-04-10 09:56:16 -07:00
commit 9011d66a29
5 changed files with 303 additions and 96 deletions

@ -260,3 +260,24 @@ def _show_diff(array1, array2):
def _avals_short(avals):
to_str = lambda aval: getattr(aval, 'str_short', partial(str, aval))()
return ' '.join(map(to_str, avals))
def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
assert not core.typematch(a1, a2)
if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
mismatches = []
if a1.dtype != a2.dtype:
mismatches.append('the dtypes do not match')
if a1.shape != a2.shape:
mismatches.append('the shapes do not match')
if a1.vma != a2.vma:
mismatches.append('the varying manual axes do not match')
# TODO(yashkatariya,mattjj): add check for sharding-in-types mismatch
if len(mismatches) == 0:
return ''
elif len(mismatches) == 1:
return ', so ' + mismatches[0]
else:
return ', so ' + ', '.join(mismatches[:-1]) + ', and ' + mismatches[-1]
return ''

@ -23,7 +23,9 @@ import itertools
import operator
from typing import Any, TypeVar
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src.tree_util import (
tree_flatten, tree_unflatten, tree_flatten_with_path, keystr,
equality_errors_pytreedef)
from jax._src import ad_util
from jax._src import api_util
from jax._src import config
@ -44,19 +46,14 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, split_list, partition_list)
from jax._src.util import safe_map, split_list, partition_list, unzip2
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
import numpy as np
from jax._src.lax.control_flow.common import (
_avals_short,
_check_tree_and_avals,
_initial_style_jaxprs_with_common_consts,
_make_closed_jaxpr,
_prune_zeros,
_typecheck_param,
)
_avals_short, _typecheck_param, _aval_mismatch_extra,
_initial_style_jaxprs_with_common_consts, _make_closed_jaxpr, _prune_zeros)
map, unsafe_map = safe_map, map
@ -147,10 +144,9 @@ def switch(index, branches: Sequence[Callable], *operands,
if config.mutable_array_checks.value:
api_util._check_no_aliased_closed_over_refs(dbgs[0], (*jaxprs[0].consts, *consts), ops)
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals("branch 0 output",
out_trees[0], jaxprs[0].out_avals,
f"branch {i + 1} output",
out_tree, jaxpr.out_avals)
_check_branch_outputs(
"switch", "branch 0", f"branch{i+1}", branches[0], branches[i+1],
out_trees[0], out_tree, jaxprs[0].out_avals, jaxpr.out_avals)
# prune passthrough outputs
fwds = [pe._jaxpr_forwarding(jaxpr.jaxpr) for jaxpr in jaxprs]
in_fwd = [xs[0] if len(set(xs)) == 1 else None for xs in zip(*fwds)]
@ -270,10 +266,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
true_jaxpr.out_avals + false_jaxpr.out_avals):
raise ValueError("Cannot return `Ref`s from `cond`.")
_check_tree_and_avals("true_fun output",
out_tree, true_jaxpr.out_avals,
"false_fun output",
false_out_tree, false_jaxpr.out_avals)
_check_branch_outputs(
'cond', 'true_fun', 'false_fun', true_fun, false_fun, out_tree,
false_out_tree, true_jaxpr.out_avals, false_jaxpr.out_avals)
# prune passthrough outputs
true_fwds = pe._jaxpr_forwarding(true_jaxpr.jaxpr)
false_fwds = pe._jaxpr_forwarding(false_jaxpr.jaxpr)
@ -303,6 +299,90 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
assert next(out_, None) is None
return tree_unflatten(out_tree, out)
def _check_branch_outputs(
api_name, name1, name2, f1, f2, out_tree1, out_tree2, out_avals1,
out_avals2) -> None:
info1 = api_util.fun_sourceinfo(f1)
info2 = api_util.fun_sourceinfo(f2)
try:
outs1 = tree_unflatten(out_tree1, out_avals1)
except:
paths = [None] * len(out_avals1)
component = lambda _: ''
else:
leaves_and_paths, _ = tree_flatten_with_path(outs1)
paths, _ = unzip2(leaves_and_paths)
component = lambda p: f' at path {keystr(p)}' if p else ''
if out_tree1 != out_tree2:
diffs = [f'{name1} output{component(p)} is a {thing1} but '
f'{name2} output{component(p)} is a {thing2}, so {expl}'
for p, thing1, thing2, expl
in equality_errors_pytreedef(out_tree1, out_tree2)]
if len(diffs) == 0:
return # the trees may have different aux data, but structures are same
elif len(diffs) == 1:
differences = f'{diffs[0]}.\n'
else:
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
+ f' * {diffs[-1]}.\n')
raise TypeError(
f'{api_name} branch outputs must have the same pytree structure, but '
'they differ:\n\n'
f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n'
f'{differences}\n'
f'Revise {name1} and/or {name2} so that they have the same pytree '
'structure.')
if not all(map(core.typematch, out_avals1, out_avals2)):
diffs = [f'the output of {name1}{component(p)} has type {a1.str_short()}'
f' but the corresponding output of {name2} has type '
f'{a2.str_short()}{_aval_mismatch_extra(a1, a2)}'
for p, a1, a2 in zip(paths, out_avals1, out_avals2)
if not core.typematch(a1, a2)]
if len(diffs) == 0:
return # seems unreachable but in any case we don't have a good error msg
elif len(diffs) == 1:
differences = f'{_capitalize(diffs[0])}.\n'
else:
differences = ('\n'.join(f' * {d};' for d in diffs[:-1])
+ f'\n * {diffs[-1]}.\n')
pvary_applications = [
f'applying `jax.lax.pvary(..., {tuple(a1.vma - a2.vma)})` '
f'to the output of {n}{component(p)}'
for p, aval1, aval2 in zip(paths, out_avals1, out_avals2)
for n, a1, a2 in [(name1, aval2, aval1), (name2, aval1, aval2)]
if not core.typematch(a1, a2) and
isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray)
and a1.vma != a2.vma and a2.vma - a1.vma]
if not pvary_applications:
pvary_msg = ''
elif len(pvary_applications) == 1:
pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n'
else:
pvary_msg = ('This might be fixed by:\n' +
'\n'.join(f' * {d};' for d in pvary_applications[:-1])
+ f'\n * {pvary_applications[-1]}.\n')
if pvary_msg:
pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma "
"for more information.\n\n")
raise TypeError(
f'{api_name} branches must have equal output types but they differ.\n\n'
f'{name1} is {info1}\n' + f'{name2} is {info2}\n\n'
f'{differences}\n'
f'{pvary_msg}'
f'Revise {name1} and/or {name2} so that all output types match.')
def _capitalize(s):
# s.capitalize() converts s[1:] to lowercase which we don't want.
return s[0].capitalize() + s[1:]
@api_boundary
@functools.wraps(_cond)
def cond(*args, **kwargs):

@ -51,7 +51,7 @@ from jax._src.lax import windowed_reductions
from jax._src.lax.control_flow.common import (
_avals_short, _initial_style_jaxpr,
_initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
_typecheck_param)
_typecheck_param, _aval_mismatch_extra)
from jax._src.lax.other import logaddexp
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
@ -60,24 +60,12 @@ from jax._src.traceback_util import api_boundary
from jax._src.tree_util import equality_errors
from jax._src.typing import Array
from jax._src.util import (
merge_lists,
partition_list,
safe_map,
safe_zip,
split_list,
split_list_checked,
unzip2,
weakref_lru_cache,
)
merge_lists, partition_list, safe_map, safe_zip, split_list,
split_list_checked, unzip2, weakref_lru_cache,)
from jax._src import xla_bridge as xb
from jax.tree_util import (
keystr,
tree_flatten,
tree_flatten_with_path,
tree_map,
tree_unflatten,
treedef_is_leaf,
)
keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten,
treedef_is_leaf)
import numpy as np
_map = safe_map
@ -428,9 +416,8 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
return # the trees may have different aux data, but structures are same
elif len(diffs) == 1:
differences = f'{_capitalize(diffs[0])}.\n'
else:
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
@ -447,32 +434,42 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
return # seems unreachable but in any case we don't have a good error msg
if len(diffs) == 1:
differences = f'{_capitalize(diffs[0])}.\n'
else:
differences = ('\n'.join(f' * {d};\n' for d in diffs[:-1])
+ f' * {diffs[-1]}.\n')
pvary_applications = [
f'applying `jax.lax.pvary(..., {tuple(out_aval.vma - in_aval.vma)})` '
f'to the initial carry value corresponding to {component(path)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval) and
isinstance(in_aval, ShapedArray) and isinstance(out_aval, ShapedArray)
and in_aval.vma != out_aval.vma and out_aval.vma - in_aval.vma]
if not pvary_applications:
pvary_msg = ''
elif len(pvary_applications) == 1:
pvary_msg = f'This might be fixed by {pvary_applications[0]}.\n'
else:
pvary_msg = ('This might be fixed by:\n' +
'\n'.join(f' * {d};\n' for d in pvary_applications[:-1])
+ f' * {pvary_applications[-1]}.\n')
if pvary_msg:
pvary_msg += ("See https://docs.jax.dev/en/latest/notebooks/shard_map.html#scan-vma "
"for more information.\n\n")
raise TypeError(
f"{name} function carry input and carry output must have equal types "
"(e.g. shapes and dtypes of arrays), "
f"{name} function carry input and carry output must have equal types, "
"but they differ:\n\n"
f"{differences}\n"
"Revise the function so that all output types (e.g. shapes "
"and dtypes) match the corresponding input types.")
def _aval_mismatch_extra(a1: core.AbstractValue, a2: core.AbstractValue) -> str:
assert not core.typematch(a1, a2)
if isinstance(a1, core.ShapedArray) and isinstance(a2, core.ShapedArray):
dtype_mismatch = a1.dtype != a2.dtype
shape_mismatch = a1.shape != a2.shape
return (', so ' * (dtype_mismatch or shape_mismatch) +
'the dtypes do not match' * dtype_mismatch +
' and also ' * (dtype_mismatch and shape_mismatch) +
'the shapes do not match' * shape_mismatch)
return ''
f"{pvary_msg}"
"Revise the function so that all output types match the corresponding "
"input types.")
# TODO(mattjj): re-land #19819 version? simpler, but caused ~1 perf regression.
def _scan_impl(*args, reverse, length, num_consts, num_carry, jaxpr, linear,

@ -588,7 +588,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
init = jnp.float32(10)
self.assertEqual(fori_loop_with_static_upper_and_lower(init), init)
def testForiLoopBatched(self):
def body_fun(i, loop_carry):
x, y = loop_carry
@ -994,16 +993,24 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError,
re.escape("Pred must be a scalar, got (1.0, 1.0) of type <class 'tuple'>")):
lax.cond((1., 1.), lambda top: 2., lambda fop: 3., 1.)
with self.assertRaisesRegex(TypeError,
re.compile("true_fun output must have same type structure "
"as false_fun output, but there are differences:.*"
r"at output\['a'\], true_fun output has pytree leaf", re.DOTALL)):
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
with self.assertRaisesRegex(
TypeError,
"true_fun output and false_fun output must have identical types, got\n"
r"DIFFERENT ShapedArray\(float32\[1\]\) vs. "
r"ShapedArray\(float32\[\].*\)."):
re.compile(
r"cond branch outputs must have the same pytree structure, but they"
r" differ:.*true_fun output at path \['a'\] is a pytree leaf but"
r" false_fun output at path \['a'\] is a <class 'tuple'>",
re.DOTALL)):
lax.cond(True, lambda top: dict(a=2.), lambda fop: dict(a=(3., 3.)), 1.)
with self.assertRaisesRegex(
TypeError,
re.compile(
r"cond branches must have equal output types but they differ.*The"
r" output of true_fun has type float32\[1\] but the corresponding"
r" output of false_fun has type float32\[\], so the shapes do not"
r" match",
re.DOTALL)):
lax.cond(True,
lambda top: jnp.array([1.], jnp.float32),
lambda fop: jnp.float32(1.),
@ -1023,16 +1030,26 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError,
re.escape("Empty branch sequence")):
lax.switch(0, [], 1.)
with self.assertRaisesRegex(TypeError,
re.compile("branch 0 output must have same type structure "
"as branch 1 output, but there are differences:.*"
r"at output\['a'\], branch 0 output has pytree leaf", re.DOTALL)):
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
with self.assertRaisesRegex(
TypeError,
"branch 0 output and branch 1 output must have identical types, got\n"
r"{'a': 'DIFFERENT ShapedArray\(float32\[1\]\) "
r"vs. ShapedArray\(float32\[\].*\)'}."):
re.compile(
"switch branch outputs must have the same pytree structure, but"
r" they differ.*branch 0 output at path \['a'\] is a pytree leaf"
r" but branch1 output at path \['a'\] is a <class 'tuple'>, so"
r" their"
" Python types differ.",
re.DOTALL)):
lax.switch(1, [lambda _: dict(a=2.), lambda _: dict(a=(3., 3.))], 1.)
with self.assertRaisesRegex(
TypeError,
re.compile(
"switch branches must have equal output types but they differ.*The"
r" output of branch 0 at path \['a'\] has type float32\[1\] but the"
r" corresponding output of branch1 has type float32\[\], so the"
" shapes do not match",
re.DOTALL)):
lax.switch(1, [lambda _: dict(a=jnp.array([1.], jnp.float32)),
lambda _: dict(a=jnp.float32(1.))],
1.)
@ -1983,7 +2000,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
TypeError,
re.escape("function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
"types, but they differ:\n\n"
"The input carry x has type int32[] but the corresponding "
"output carry component has type float32[], so the dtypes do "
"not match"
@ -1994,7 +2011,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
TypeError,
re.escape("function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
"types, but they differ:\n\n"
"The input carry component x[1] has type int32[] but the "
"corresponding output carry component has type float32[], "
"so the dtypes do not match"
@ -2005,13 +2022,13 @@ class LaxControlFlowTest(jtu.JaxTestCase):
with self.assertRaisesRegex(
TypeError,
re.escape("function carry input and carry output must have equal "
"types (e.g. shapes and dtypes of arrays), but they differ:\n\n"
"types, but they differ:\n\n"
" * the input carry component x[0] has type int32[] but the "
"corresponding output carry component has type float32[], "
"so the dtypes do not match;\n"
" * the input carry component x[1] has type int32[] but the "
"corresponding output carry component has type float32[1,1], "
"so the dtypes do not match and also the shapes do not match."
"so the dtypes do not match, and the shapes do not match."
)):
jax.lax.scan(lambda x, _: ((x[0].astype('float32'),
x[1].astype('float32').reshape(1, 1),
@ -2495,7 +2512,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertLess(len(scan_unrolled_hlo), len(scan_fully_unrolled_hlo))
# and the lowering should contain a while loop, unless the scan is fully
# unrolled
# unrolled
self.assertIn("while(", scan_hlo)
self.assertIn("while(", scan_unrolled_hlo)
self.assertNotIn("while(", scan_fully_unrolled_hlo)
@ -2786,7 +2803,6 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
def f(i, x):
def cond(carry):
@ -3076,7 +3092,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def leak():
data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1)
def g():
return jax.lax.cond(
return jax.lax.cond(
True,
lambda: data[0], # noqa: F821
lambda: data[1], # noqa: F821

@ -1084,11 +1084,9 @@ class ShardMapTest(jtu.JaxTestCase):
def f(x, y):
def true_fn(x, y):
return x
return lax.pvary(x, 'y')
def false_fun(x, y):
return y
x = lax.pvary(x, 'y')
y = lax.pvary(y, 'x')
return lax.pvary(y, 'x')
return jax.lax.cond(True, true_fn, false_fun, x, y)
shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
@ -2809,10 +2807,10 @@ class ShardMapTest(jtu.JaxTestCase):
xs = jax.lax.with_sharding_constraint(xs, NamedSharding(mesh, P('i')))
def fun(v, xs):
# Commenting this single line below makes everything work
v = jax.scipy.linalg.expm(v)
v = v.sum()
return v * xs.sum(axis=-1).astype(v.dtype)
# Commenting this single line below makes everything work
v = jax.scipy.linalg.expm(v)
v = v.sum()
return v * xs.sum(axis=-1).astype(v.dtype)
res = fun(variables, xs)
fun_shard_map = shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs)
@ -2831,25 +2829,120 @@ class ShardMapTest(jtu.JaxTestCase):
mesh = jtu.create_mesh((4,), ('i',))
@jax.custom_jvp
def f(a: jax.Array, b: jax.Array) -> jax.Array:
return a + b
return a + b
@partial(f.defjvp, symbolic_zeros=True)
def f_jvp(primals, tangents):
a, b = primals
a_dot, b_dot = tangents
y = f(a, b)
y_dot = jnp.zeros_like(y)
if not isinstance(a_dot, SymbolicZero):
y_dot += a_dot
if not isinstance(b_dot, SymbolicZero):
y_dot += b_dot
return y, y_dot
a, b = primals
a_dot, b_dot = tangents
y = f(a, b)
y_dot = jnp.zeros_like(y)
if not isinstance(a_dot, SymbolicZero):
y_dot += a_dot
if not isinstance(b_dot, SymbolicZero):
y_dot += b_dot
return y, y_dot
x = jax.random.normal(jax.random.key(0), (jax.device_count(), 20))
A = jax.random.normal(jax.random.key(1), (jax.device_count(), 20))
g = shard_map(f, mesh, in_specs=P('i'), out_specs=P('i'))
jax.jvp(lambda x: g(x, A), (x,), (x,)) # don't crash
def test_cond_pvary_errors(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
def f(x, y):
def true_fn(x, y):
return x
def false_fun(x, y):
return y
return jax.lax.cond(True, true_fn, false_fun, x, y)
x = jnp.arange(4.)
with self.assertRaisesRegex(
TypeError,
r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"):
shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
def test_cond_pvary_errors_pytree(self):
mesh = jtu.create_mesh((1, 1), ('x', 'y'))
def f(x, y):
def true_fn(x, y):
return x, y
def false_fun(x, y):
return y, x
return jax.lax.cond(True, true_fn, false_fun, x, y)
x = jnp.arange(4.)
with self.assertRaisesRegex(
TypeError,
r"applying `jax.lax.pvary\(..., \('y',\)\)` to the output of true_fun"):
shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x)
def test_scan_pvary_errors(self):
mesh = jtu.create_mesh((1, 1), ('i', 'j'))
x = jnp.arange(3.)
y = jnp.arange(3.)
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def f(x, y):
def body(carry, _):
c1, c2 = carry
return (c2, c1), () # swap the carry
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
return x_, y_
with self.assertRaisesRegex(
TypeError,
r"This might be fixed by applying `jax.lax.pvary\(..., \('i',\)\)` to"
r' the initial'):
f(x, y)
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P()), out_specs=P('i'))
def g(x, y):
def body(carry, _):
c1, c2 = carry
return (c2, c1), ()
y = jax.lax.pvary(y, 'i') # fix the issue
(x_, y_), _ = jax.lax.scan(body, (x, y), (), length=2)
return x_, y_
g(x, y) # doesn't crash
def test_scan_pvary_errors2(self):
mesh = jtu.create_mesh((1, 1), ('i', 'j'))
x = jnp.arange(3.)
y = jnp.arange(3.)
z = jnp.arange(3.)
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j')))
def f(x, y, z):
def body(carry, _):
c1, c2, c3 = carry
return (c3, c1, c2), () # swap the carry
# x = jax.lax.pvary(x, 'j')
# y = jax.lax.pvary(y, ('i', 'j'))
carry, _ = jax.lax.scan(body, (x, y, z), (), length=2)
return carry
with self.assertRaisesRegex(
TypeError,
r'This might be fixed by:\n \* applying `jax.lax.pvary\(...,'
r" \('j',\)\)`"):
f(x, y, z)
@partial(shard_map, mesh=mesh, in_specs=(P('i'), P(), P(('i', 'j'))), out_specs=P(('i', 'j')))
def g(x, y, z):
def body(carry, _):
c1, c2, c3 = carry
return (c3, c1, c2), () # swap the carry
x = jax.lax.pvary(x, 'j') # fix the issue
y = jax.lax.pvary(y, ('i', 'j'))
carry, _ = jax.lax.scan(body, (x, y, z), (), length=2)
return carry
g(x, y, z) # doesn't crash
class FunSpec(NamedTuple):
name: str