mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fixed _mapped_axis_size raising an uncaught TypeError
This commit is contained in:
parent
f08e52faef
commit
9ec6ebb7e0
@ -1275,6 +1275,11 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
msg = f"{name} must have at least one non-None value in in_axes"
|
||||
raise ValueError(msg)
|
||||
|
||||
def _get_argument_type(x):
|
||||
try:
|
||||
return shaped_abstractify(x).str_short()
|
||||
except TypeError: #Catch all for user specified objects that can't be interpreted as a data type
|
||||
return "unknown"
|
||||
msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"]
|
||||
args, kwargs = tree_unflatten(tree, vals)
|
||||
try:
|
||||
@ -1283,15 +1288,15 @@ def _mapped_axis_size(fn, tree, vals, dims, name):
|
||||
ba = None
|
||||
if ba is None:
|
||||
args_paths = [f'args{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
f'of type {_get_argument_type(x)}'
|
||||
for p, x in generate_key_paths(args)]
|
||||
kwargs_paths = [f'kwargs{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
f'of type {_get_argument_type(x)}'
|
||||
for p, x in generate_key_paths(kwargs)]
|
||||
key_paths = [*args_paths, *kwargs_paths]
|
||||
else:
|
||||
key_paths = [f'argument {name}{keystr(p)} '
|
||||
f'of type {shaped_abstractify(x).str_short()}'
|
||||
f'of type {_get_argument_type(x)}'
|
||||
for name, arg in ba.arguments.items()
|
||||
for p, x in generate_key_paths(arg)]
|
||||
all_sizes = [_get_axis_size(name, np.shape(x), d) if d is not None else None
|
||||
|
Loading…
x
Reference in New Issue
Block a user