mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Match output container to result_shape_dtypes in ffi_call.
Previously, ffi_call would always return a list for multiple results, but if the input `result_shape_dtypes` is a tuple, we should return a tuple. PiperOrigin-RevId: 725834048
This commit is contained in:
parent
fd12f30011
commit
bba09137dc
@ -482,6 +482,8 @@ def ffi_call(
|
||||
attributes=_wrap_kwargs_hashable(kwargs),
|
||||
)
|
||||
if multiple_results:
|
||||
if isinstance(result_shape_dtypes, tuple):
|
||||
return tuple(results)
|
||||
return results
|
||||
else:
|
||||
return results[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user