[XLA:Python] Validate shapes in Python bindings to avoid crashes.

[JAX] Perform LAPACK workspace calculations in int64 to avoid overflows, clamp the values passed to lapack to int32.

Will fix https://github.com/google/jax/issues/4358 when incorporated into a jaxlib.

PiperOrigin-RevId: 337367394
This commit is contained in:
Peter Hawkins 2020-10-15 13:09:37 -07:00 committed by jax authors
parent 22c3684d3b
commit 7f4e115a6a

View File

@ -47,6 +47,8 @@ _ops = xla_client.ops
Shape = xla_client.Shape
cdef int _int32_max = 0x7FFFFFFF;
cdef register_cpu_custom_call_target(fn_name, void* fn):
cdef const char* name = "xla._CUSTOM_CALL_TARGET"
@ -964,15 +966,18 @@ def potrf(c, a, lower=False):
# ?gesdd: Singular value decomposition
cdef int gesdd_iwork_size(int m, int n) nogil:
return 8 * min(m, n)
cdef int gesdd_iwork_size(int64_t m, int64_t n) nogil:
# Avoid integer overflow; the LAPACK integer type is int32.
return min(_int32_max, 8 * min(m, n))
cdef int cgesdd_rwork_size(int m, int n, int compute_uv) nogil:
cdef int mn = min(m, n)
cdef int cgesdd_rwork_size(int64_t m, int64_t n, int compute_uv) nogil:
cdef int64_t mn = min(m, n)
if compute_uv == 0:
return 7 * mn
cdef int mx = max(m, n)
return max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn)
cdef int64_t mx = max(m, n)
# Avoid integer overflow; the LAPACK integer type is int32.
return min(_int32_max,
max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn))
cdef char gesdd_jobz(bool_t job_opt_compute_uv,
bool_t job_opt_full_matrices) nogil:
@ -1286,11 +1291,13 @@ def gesdd(c, a, full_matrices=True, compute_uv=True):
# syevd: Symmetric eigendecomposition
# Workspace sizes, taken from the LAPACK documentation.
cdef int syevd_work_size(int n) nogil:
return 1 + 6 * n + 2 * n * n
cdef int syevd_work_size(int64_t n) nogil:
# Avoids int32 overflow.
return min(_int32_max, 1 + 6 * n + 2 * n * n)
cdef int syevd_iwork_size(int n) nogil:
return 3 + 5 * n
cdef int syevd_iwork_size(int64_t n) nogil:
# Avoids int32 overflow.
return min(_int32_max, 3 + 5 * n)
cdef void lapack_ssyevd(void* out_tuple, void** data) nogil:
cdef int32_t lower = (<int32_t*>(data[0]))[0]
@ -1352,11 +1359,13 @@ cdef void lapack_dsyevd(void* out_tuple, void** data) nogil:
register_cpu_custom_call_target(b"lapack_dsyevd", <void*>(lapack_dsyevd))
# Workspace sizes, taken from the LAPACK documentation.
cdef int heevd_work_size(int n) nogil:
return 1 + 2 * n + n * n
cdef int heevd_work_size(int64_t n) nogil:
# Avoid int32 overflow.
return min(_int32_max, 1 + 2 * n + n * n)
cdef int heevd_rwork_size(int n) nogil:
return 1 + 5 * n + 2 * n * n
cdef int heevd_rwork_size(int64_t n) nogil:
# Avoid int32 overflow.
return min(_int32_max, 1 + 5 * n + 2 * n * n)
cdef void lapack_cheevd(void* out_tuple, void** data) nogil: