mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[XLA:Python] Validate shapes in Python bindings to avoid crashes.
[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32. Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib. PiperOrigin-RevId: 337367394
This commit is contained in:
parent
22c3684d3b
commit
7f4e115a6a
@ -47,6 +47,8 @@ _ops = xla_client.ops
|
||||
|
||||
Shape = xla_client.Shape
|
||||
|
||||
cdef int _int32_max = 0x7FFFFFFF;
|
||||
|
||||
|
||||
cdef register_cpu_custom_call_target(fn_name, void* fn):
|
||||
cdef const char* name = "xla._CUSTOM_CALL_TARGET"
|
||||
@ -964,15 +966,18 @@ def potrf(c, a, lower=False):
|
||||
|
||||
# ?gesdd: Singular value decomposition
|
||||
|
||||
cdef int gesdd_iwork_size(int m, int n) nogil:
|
||||
return 8 * min(m, n)
|
||||
cdef int gesdd_iwork_size(int64_t m, int64_t n) nogil:
|
||||
# Avoid integer overflow; the LAPACK integer type is int32.
|
||||
return min(_int32_max, 8 * min(m, n))
|
||||
|
||||
cdef int cgesdd_rwork_size(int m, int n, int compute_uv) nogil:
|
||||
cdef int mn = min(m, n)
|
||||
cdef int cgesdd_rwork_size(int64_t m, int64_t n, int compute_uv) nogil:
|
||||
cdef int64_t mn = min(m, n)
|
||||
if compute_uv == 0:
|
||||
return 7 * mn
|
||||
cdef int mx = max(m, n)
|
||||
return max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn)
|
||||
cdef int64_t mx = max(m, n)
|
||||
# Avoid integer overflow; the LAPACK integer type is int32.
|
||||
return min(_int32_max,
|
||||
max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn))
|
||||
|
||||
cdef char gesdd_jobz(bool_t job_opt_compute_uv,
|
||||
bool_t job_opt_full_matrices) nogil:
|
||||
@ -1286,11 +1291,13 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
|
||||
# syevd: Symmetric eigendecomposition
|
||||
|
||||
# Workspace sizes, taken from the LAPACK documentation.
|
||||
cdef int syevd_work_size(int n) nogil:
|
||||
return 1 + 6 * n + 2 * n * n
|
||||
cdef int syevd_work_size(int64_t n) nogil:
|
||||
# Avoids int32 overflow.
|
||||
return min(_int32_max, 1 + 6 * n + 2 * n * n)
|
||||
|
||||
cdef int syevd_iwork_size(int n) nogil:
|
||||
return 3 + 5 * n
|
||||
cdef int syevd_iwork_size(int64_t n) nogil:
|
||||
# Avoids int32 overflow.
|
||||
return min(_int32_max, 3 + 5 * n)
|
||||
|
||||
cdef void lapack_ssyevd(void* out_tuple, void** data) nogil:
|
||||
cdef int32_t lower = (<int32_t*>(data[0]))[0]
|
||||
@ -1352,11 +1359,13 @@ cdef void lapack_dsyevd(void* out_tuple, void** data) nogil:
|
||||
register_cpu_custom_call_target(b"lapack_dsyevd", <void*>(lapack_dsyevd))
|
||||
|
||||
# Workspace sizes, taken from the LAPACK documentation.
|
||||
cdef int heevd_work_size(int n) nogil:
|
||||
return 1 + 2 * n + n * n
|
||||
cdef int heevd_work_size(int64_t n) nogil:
|
||||
# Avoid int32 overflow.
|
||||
return min(_int32_max, 1 + 2 * n + n * n)
|
||||
|
||||
cdef int heevd_rwork_size(int n) nogil:
|
||||
return 1 + 5 * n + 2 * n * n
|
||||
cdef int heevd_rwork_size(int64_t n) nogil:
|
||||
# Avoid int32 overflow.
|
||||
return min(_int32_max, 1 + 5 * n + 2 * n * n)
|
||||
|
||||
|
||||
cdef void lapack_cheevd(void* out_tuple, void** data) nogil:
|
||||
|
Loading…
x
Reference in New Issue
Block a user