mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #7557 from jakevdp:fix-matmat-validation
PiperOrigin-RevId: 389751602
This commit is contained in:
commit
23f91d6909
@ -215,7 +215,7 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
|
||||
assert indices.dtype == indptr.dtype
|
||||
assert len(indptr) == shape[0] + 1
|
||||
out_shape = shape[1] if transpose else shape[0]
|
||||
assert v.shape == (shape[0],) if transpose else (shape[1],)
|
||||
assert v.shape[0] == (shape[0] if transpose else shape[1])
|
||||
return core.ShapedArray((out_shape,), data.dtype)
|
||||
|
||||
def _csr_matvec_gpu_translation_rule(c, data, indices, indptr, v, *, shape, transpose):
|
||||
@ -259,6 +259,7 @@ def _csr_matmat_impl(data, indices, indptr, B, *, shape, transpose):
|
||||
|
||||
@csr_matmat_p.def_abstract_eval
|
||||
def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
|
||||
assert len(shape) == 2
|
||||
assert data.ndim == indices.ndim == indptr.ndim == 1
|
||||
assert B.ndim == 2
|
||||
assert data.shape == indices.shape
|
||||
@ -266,7 +267,7 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
|
||||
assert indices.dtype == indptr.dtype
|
||||
assert len(indptr) == shape[0] + 1
|
||||
out_shape = shape[1] if transpose else shape[0]
|
||||
assert B.shape[0] == shape[0] if transpose else shape[1]
|
||||
assert B.shape[0] == (shape[0] if transpose else shape[1])
|
||||
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
|
||||
|
||||
def _csr_matmat_gpu_translation_rule(c, data, indices, indptr, B, *, shape, transpose):
|
||||
@ -451,7 +452,8 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
|
||||
assert data.dtype == v.dtype
|
||||
assert row.dtype == col.dtype
|
||||
assert len(shape) == 2
|
||||
assert v.shape == (shape[0],) if transpose else (shape[1],)
|
||||
assert v.ndim == 1
|
||||
assert v.shape[0] == (shape[0] if transpose else shape[1])
|
||||
out_shape = shape[1] if transpose else shape[0]
|
||||
return core.ShapedArray((out_shape,), data.dtype)
|
||||
|
||||
@ -520,8 +522,9 @@ def _coo_matmat_impl(data, row, col, B, *, shape, transpose):
|
||||
def _coo_matmat_abstract_eval(data, row, col, B, *, shape, transpose):
|
||||
assert data.shape == row.shape == col.shape
|
||||
assert data.dtype == B.dtype
|
||||
assert B.ndim == 2
|
||||
assert len(shape) == 2
|
||||
assert B.shape[0] == shape[0] if transpose else shape[1]
|
||||
assert B.shape[0] == (shape[0] if transpose else shape[1])
|
||||
out_shape = shape[1] if transpose else shape[0]
|
||||
return core.ShapedArray((out_shape, B.shape[1]), data.dtype)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user