Match output container to result_shape_dtypes in ffi_call.

Previously, ffi_call would always return a list for multiple results, but if the input `result_shape_dtypes` is a tuple, we should return a tuple.

PiperOrigin-RevId: 725834048
This commit is contained in:
Dan Foreman-Mackey 2025-02-11 17:32:48 -08:00 committed by jax authors
parent fd12f30011
commit bba09137dc

View File

@ -482,6 +482,8 @@ def ffi_call(
attributes=_wrap_kwargs_hashable(kwargs),
)
if multiple_results:
if isinstance(result_shape_dtypes, tuple):
return tuple(results)
return results
else:
return results[0]