Be assertive that we support unordered sparse tensors.
PiperOrigin-RevId: 360750129
Change-Id: If44b18ce8894a0608069dc09ed705aef0433942f
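For context, a minimal sketch (not part of the commit) of the behavior this change documents and tests: a SparseTensor whose indices are *not* in lexicographic order still multiplies correctly, with no reorder needed. The public TF 2.x names `tf.sparse.SparseTensor` and `tf.sparse.sparse_dense_matmul` are assumed here; the diff itself uses the internal `sparse_ops` module. Values mirror the new test below, with width 4 instead of 500 for readability.

```python
# Hedged sketch, assuming a TF 2.x eager environment.
import numpy as np
import tensorflow as tf

indices = np.array([(2, 1), (0, 0)]).astype(np.int64)  # deliberately unordered
values = np.array([10, 11]).astype(np.float32)
sparse_t = tf.sparse.SparseTensor(indices, values, [3, 2])

dense_t = np.array([[1] * 4, [2] * 4], dtype=np.float32)
print(tf.sparse.sparse_dense_matmul(sparse_t, dense_t))
# [[11. 11. 11. 11.]
#  [ 0.  0.  0.  0.]
#  [20. 20. 20. 20.]]
```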
@@ -163,6 +163,18 @@ class SparseTensorDenseMatMulTest(test.TestCase):
           sparse_ops.sparse_tensor_dense_matmul(
               sparse_t, dense_t, adjoint_a=True))
 
+  def testUnorderedIndicesForSparseTensorDenseMatmul(self):
+    indices = np.array([(2, 1), (0, 0)]).astype(np.int64)
+    values = np.array([10, 11]).astype(np.float32)
+    shape = [3, 2]
+    sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
+
+    dense_t = np.array([[1] * 500, [2] * 500], dtype=np.float32)
+    expected_t = np.array([[11] * 500, [0] * 500, [20] * 500], dtype=np.float32)
+
+    self.assertAllClose(
+        expected_t, sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
+
   @test_util.run_gpu_only
   def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
     indices = np.array([[1, 10]]).astype(np.int64)
@@ -2431,14 +2431,31 @@ def sparse_tensor_dense_matmul(sp_a,
   (or SparseTensor) "B". Please note that one and only one of the inputs MUST
   be a SparseTensor and the other MUST be a dense matrix.
 
-  No validity checking is performed on the indices of `A`. However, the
-  following input format is recommended for optimal behavior:
+  The following input format is recommended (but not required) for optimal
+  performance:
 
   * If `adjoint_a == false`: `A` should be sorted in lexicographically
     increasing order. Use `sparse.reorder` if you're not sure.
   * If `adjoint_a == true`: `A` should be sorted in order of increasing
     dimension 1 (i.e., "column major" order instead of "row major" order).
 
+  Args:
+    sp_a: SparseTensor (or dense Matrix) A, of rank 2.
+    b: dense Matrix (or SparseTensor) B, with the same dtype as sp_a.
+    adjoint_a: Use the adjoint of A in the matrix multiply. If A is complex,
+      this is transpose(conj(A)). Otherwise it's transpose(A).
+    adjoint_b: Use the adjoint of B in the matrix multiply. If B is complex,
+      this is transpose(conj(B)). Otherwise it's transpose(B).
+    name: A name prefix for the returned tensors (optional)
+
+  Returns:
+    A dense matrix (pseudo-code in dense np.matrix notation):
+      `A = A.H if adjoint_a else A`
+      `B = B.H if adjoint_b else B`
+      `return A*B`
+
   Notes:
 
   Using `tf.nn.embedding_lookup_sparse` for sparse multiplication:
 
   It's not obvious but you can consider `embedding_lookup_sparse` as another
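A short sketch of the recommendation in the bullets above, using the public `tf.sparse.reorder` (the docstring's `sparse.reorder`); TF 2.x eager execution is assumed. Per the updated docstring, the reorder is a performance suggestion only — correctness no longer depends on it.

```python
# Hedged sketch: lexicographically sort the indices before the matmul, as
# recommended for the adjoint_a == false case.
import tensorflow as tf

sp_a = tf.sparse.SparseTensor(
    indices=[[2, 1], [0, 0]], values=[10.0, 11.0], dense_shape=[3, 2])
sp_a = tf.sparse.reorder(sp_a)  # row-major sort of the indices
b = tf.ones([2, 4], dtype=tf.float32)
print(tf.sparse.sparse_dense_matmul(sp_a, b, adjoint_a=False))
```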
@@ -2610,20 +2627,6 @@ def sparse_tensor_dense_matmul(sp_a,
   0.8   25    False   1000  1000  0.00211448  0.00752736  3.55992
   ```
 
-  Args:
-    sp_a: SparseTensor (or dense Matrix) A, of rank 2.
-    b: dense Matrix (or SparseTensor) B, with the same dtype as sp_a.
-    adjoint_a: Use the adjoint of A in the matrix multiply. If A is complex,
-      this is transpose(conj(A)). Otherwise it's transpose(A).
-    adjoint_b: Use the adjoint of B in the matrix multiply. If B is complex,
-      this is transpose(conj(B)). Otherwise it's transpose(B).
-    name: A name prefix for the returned tensors (optional)
-
-  Returns:
-    A dense matrix (pseudo-code in dense np.matrix notation):
-      `A = A.H if adjoint_a else A`
-      `B = B.H if adjoint_b else B`
-      `return A*B`
   """
 
   # pylint: enable=line-too-long
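The Returns pseudo-code (`A = A.H if adjoint_a else A`, then `A*B`) can be checked against a dense computation; a hedged sketch, assuming TF 2.x and using `tf.sparse.from_dense` purely to build the input:

```python
# Hedged sketch: with adjoint_a=True the op computes conj(A).T @ B, matching
# the pseudo-code above. Compared against the dense tf.linalg.matmul path.
import tensorflow as tf

a_dense = tf.constant([[1.0, 0.0, 2.0], [0.0, 3.0, 0.0]])  # 2x3
sp_a = tf.sparse.from_dense(a_dense)
b = tf.constant([[1.0, 2.0], [3.0, 4.0]])                  # 2x2

sparse_result = tf.sparse.sparse_dense_matmul(sp_a, b, adjoint_a=True)  # 3x2
dense_result = tf.linalg.matmul(a_dense, b, adjoint_a=True)
tf.debugging.assert_near(sparse_result, dense_result)  # A.H @ B both ways
```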