mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Test for gtsv2 attr on cusparse.
This commit is contained in:
parent
9493b315bc
commit
7cc277d634
@ -1361,7 +1361,7 @@ tridiagonal_solve_p.def_impl(
|
||||
functools.partial(xla.apply_primitive, tridiagonal_solve_p))
|
||||
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
|
||||
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
|
||||
if cusparse is not None:
|
||||
if cusparse is not None and hasattr(cusparse, "gtsv2"):
|
||||
xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user