From 7cc277d63494f85a9130b954a356847f70630b8e Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Wed, 23 Jun 2021 14:22:08 -0700 Subject: [PATCH] Test for gtsv2 attr on cusparse. --- jax/_src/lax/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 84850d49c..052dd567c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1361,7 +1361,7 @@ tridiagonal_solve_p.def_impl( functools.partial(xla.apply_primitive, tridiagonal_solve_p)) tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b, *, m, n, ldb, t: b) # TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve? -if cusparse is not None: +if cusparse is not None and hasattr(cusparse, "gtsv2"): xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2