Fixed _mapped_axis_size raising an uncaught TypeError

This commit is contained in:
treyra 2023-07-06 18:51:44 -07:00
parent f08e52faef
commit 9ec6ebb7e0

View File

@ -1275,6 +1275,11 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
msg = f"{name} must have at least one non-None value in in_axes"
raise ValueError(msg)
def _get_argument_type(x):
try:
return shaped_abstractify(x).str_short()
except TypeError: #Catch all for user specified objects that can't be interpreted as a data type
return "unknown"
msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"]
args, kwargs = tree_unflatten(tree, vals)
try:
@ -1283,15 +1288,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
ba = None
if ba is None:
args_paths = [f'args{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
f'of type {_get_argument_type(x)}'
for p, x in generate_key_paths(args)]
kwargs_paths = [f'kwargs{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
f'of type {_get_argument_type(x)}'
for p, x in generate_key_paths(kwargs)]
key_paths = [*args_paths, *kwargs_paths]
else:
key_paths = [f'argument {name}{keystr(p)} '
f'of type {shaped_abstractify(x).str_short()}'
f'of type {_get_argument_type(x)}'
for name, arg in ba.arguments.items()
for p, x in generate_key_paths(arg)]
all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None