Increase threshold for switching to unbatched triangular solve on GPU

This commit is contained in:
David Pfau 2021-03-22 21:29:39 +00:00
parent 894873e317
commit 809d689582

View File

@ -707,7 +707,7 @@ def _triangular_solve_gpu_translation_rule(trsm_impl,
if conjugate_a and not transpose_a:
a = xops.Conj(a)
conjugate_a = False
if batch > 1 and m <= 32 and n <= 32:
if batch > 1 and m <= 256 and n <= 256:
return trsm_impl(
c, a, b, left_side, lower, transpose_a,
conjugate_a, unit_diagonal)