mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #7082 from zhangqiaorjc:gtsv2fix
PiperOrigin-RevId: 381118896
This commit is contained in:
commit
c985d76ee5
@ -1361,7 +1361,7 @@ tridiagonal_solve_p.def_impl(
|
||||
functools.partial(xla.apply_primitive, tridiagonal_solve_p))
|
||||
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b)
|
||||
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
|
||||
if cusparse is not None:
|
||||
if cusparse is not None and hasattr(cusparse, "gtsv2"):
|
||||
xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user