Merge pull request #7082 from zhangqiaorjc:gtsv2fix

PiperOrigin-RevId: 381118896
This commit is contained in:
jax authors 2021-06-23 14:44:41 -07:00
commit c985d76ee5

View File

@ -1361,7 +1361,7 @@ tridiagonal_solve_p.def_impl(
functools.partial(xla.apply_primitive, tridiagonal_solve_p))
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
if cusparse is not None:
if cusparse is not None and hasattr(cusparse, "gtsv2"):
xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2