Improved extern selection in Pallas GPU

Previously,

* weakly typed avals matched the wrong externs;
* this was addressed by #23193, which disallowed weakly typed avals entirely.

Here we check if a weakly typed aval can be casted to the extern input dtype
when selecting an extern.

PiperOrigin-RevId: 669378582
This commit is contained in:
Sergei Lebedev 2024-08-30 10:52:58 -07:00 committed by jax authors
parent f8a46629c2
commit fb7fa2a09e

View File

@ -580,6 +580,7 @@ class _Extern:
return False
return all(
aval.dtype == jnp.dtype(arg_type)
or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind)
for aval, arg_type in zip(avals, self.arg_types)
)