mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Improved extern selection in Pallas GPU
Previously, * weakly typed avals matched the wrong externs; * this was addressed by #23193, which disallowed weakly typed avals entirely. Here we check if a weakly typed aval can be casted to the extern input dtype when selecting an extern. PiperOrigin-RevId: 669378582
This commit is contained in:
parent
f8a46629c2
commit
fb7fa2a09e
@ -580,6 +580,7 @@ class _Extern:
|
||||
return False
|
||||
return all(
|
||||
aval.dtype == jnp.dtype(arg_type)
|
||||
or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind)
|
||||
for aval, arg_type in zip(avals, self.arg_types)
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user