Merge pull request #7557 from jakevdp:fix-matmat-validation

PiperOrigin-RevId: 389751602
This commit is contained in:
jax authors 2021-08-09 16:31:28 -07:00
commit 23f91d6909

View File

@ -215,7 +215,7 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
assert indices.dtype == indptr.dtype
assert len(indptr) == shape[0] + 1
out_shape = shape[1] if transpose else shape[0]
assert v.shape == (shape[0],) if transpose else (shape[1],)
assert v.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape,), data.dtype)
def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
@ -259,6 +259,7 @@ def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
@csr_matmat_p.def_abstract_eval
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
assert len(shape) == 2
assert data.ndim == indices.ndim == indptr.ndim == 1
assert B.ndim == 2
assert data.shape == indices.shape
@ -266,7 +267,7 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
assert indices.dtype == indptr.dtype
assert len(indptr) == shape[0] + 1
out_shape = shape[1] if transpose else shape[0]
assert B.shape[0] == shape[0] if transpose else shape[1]
assert B.shape[0] == (shape[0] if transpose else shape[1])
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
@ -451,7 +452,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
assert data.dtype == v.dtype
assert row.dtype == col.dtype
assert len(shape) == 2
assert v.shape == (shape[0],) if transpose else (shape[1],)
assert v.ndim == 1
assert v.shape[0] == (shape[0] if transpose else shape[1])
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape,), data.dtype)
@ -520,8 +522,9 @@ def _coo_matmat_impl(data, row, col, B, *, shape, transpose):
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
assert data.shape == row.shape == col.shape
assert data.dtype == B.dtype
assert B.ndim == 2
assert len(shape) == 2
assert B.shape[0] == shape[0] if transpose else shape[1]
assert B.shape[0] == (shape[0] if transpose else shape[1])
out_shape = shape[1] if transpose else shape[0]
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)