mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Increase threshold for switching to unbatched triangular solve on GPU
This commit is contained in:
parent
894873e317
commit
809d689582
@ -707,7 +707,7 @@ def _triangular_solve_gpu_translation_rule(trsm_impl,
|
||||
if conjugate_a and not transpose_a:
|
||||
a = xops.Conj(a)
|
||||
conjugate_a = False
|
||||
if batch > 1 and m <= 32 and n <= 32:
|
||||
if batch > 1 and m <= 256 and n <= 256:
|
||||
return trsm_impl(
|
||||
c, a, b, left_side, lower, transpose_a,
|
||||
conjugate_a, unit_diagonal)
|
||||
|
Loading…
x
Reference in New Issue
Block a user