Be assertive that we support unordered sparse tensors.
PiperOrigin-RevId: 360750129 Change-Id: If44b18ce8894a0608069dc09ed705aef0433942f
This commit is contained in:
parent
07ac149508
commit
3a3a7631fa
@ -163,6 +163,18 @@ class SparseTensorDenseMatMulTest(test.TestCase):
|
|||||||
sparse_ops.sparse_tensor_dense_matmul(
|
sparse_ops.sparse_tensor_dense_matmul(
|
||||||
sparse_t, dense_t, adjoint_a=True))
|
sparse_t, dense_t, adjoint_a=True))
|
||||||
|
|
||||||
|
def testUnorderedIndicesForSparseTensorDenseMatmul(self):
|
||||||
|
indices = np.array([(2, 1), (0, 0)]).astype(np.int64)
|
||||||
|
values = np.array([10, 11]).astype(np.float32)
|
||||||
|
shape = [3, 2]
|
||||||
|
sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
|
||||||
|
|
||||||
|
dense_t = np.array([[1] * 500, [2] * 500], dtype=np.float32)
|
||||||
|
expected_t = np.array([[11] * 500, [0] * 500, [20] * 500], dtype=np.float32)
|
||||||
|
|
||||||
|
self.assertAllClose(
|
||||||
|
expected_t, sparse_ops.sparse_tensor_dense_matmul(sparse_t, dense_t))
|
||||||
|
|
||||||
@test_util.run_gpu_only
|
@test_util.run_gpu_only
|
||||||
def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
|
def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
|
||||||
indices = np.array([[1, 10]]).astype(np.int64)
|
indices = np.array([[1, 10]]).astype(np.int64)
|
||||||
|
@ -2431,14 +2431,31 @@ def sparse_tensor_dense_matmul(sp_a,
|
|||||||
(or SparseTensor) "B". Please note that one and only one of the inputs MUST
|
(or SparseTensor) "B". Please note that one and only one of the inputs MUST
|
||||||
be a SparseTensor and the other MUST be a dense matrix.
|
be a SparseTensor and the other MUST be a dense matrix.
|
||||||
|
|
||||||
No validity checking is performed on the indices of `A`. However, the
|
The following input format is recommended (but not required) for optimal
|
||||||
following input format is recommended for optimal behavior:
|
performance:
|
||||||
|
|
||||||
* If `adjoint_a == false`: `A` should be sorted in lexicographically
|
* If `adjoint_a == false`: `A` should be sorted in lexicographically
|
||||||
increasing order. Use `sparse.reorder` if you're not sure.
|
increasing order. Use `sparse.reorder` if you're not sure.
|
||||||
* If `adjoint_a == true`: `A` should be sorted in order of increasing
|
* If `adjoint_a == true`: `A` should be sorted in order of increasing
|
||||||
dimension 1 (i.e., "column major" order instead of "row major" order).
|
dimension 1 (i.e., "column major" order instead of "row major" order).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sp_a: SparseTensor (or dense Matrix) A, of rank 2.
|
||||||
|
b: dense Matrix (or SparseTensor) B, with the same dtype as sp_a.
|
||||||
|
adjoint_a: Use the adjoint of A in the matrix multiply. If A is complex,
|
||||||
|
this is transpose(conj(A)). Otherwise it's transpose(A).
|
||||||
|
adjoint_b: Use the adjoint of B in the matrix multiply. If B is complex,
|
||||||
|
this is transpose(conj(B)). Otherwise it's transpose(B).
|
||||||
|
name: A name prefix for the returned tensors (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dense matrix (pseudo-code in dense np.matrix notation):
|
||||||
|
`A = A.H if adjoint_a else A`
|
||||||
|
`B = B.H if adjoint_b else B`
|
||||||
|
`return A*B`
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
Using `tf.nn.embedding_lookup_sparse` for sparse multiplication:
|
Using `tf.nn.embedding_lookup_sparse` for sparse multiplication:
|
||||||
|
|
||||||
It's not obvious but you can consider `embedding_lookup_sparse` as another
|
It's not obvious but you can consider `embedding_lookup_sparse` as another
|
||||||
@ -2610,20 +2627,6 @@ def sparse_tensor_dense_matmul(sp_a,
|
|||||||
0.8 25 False 1000 1000 0.00211448 0.00752736 3.55992
|
0.8 25 False 1000 1000 0.00211448 0.00752736 3.55992
|
||||||
```
|
```
|
||||||
|
|
||||||
Args:
|
|
||||||
sp_a: SparseTensor (or dense Matrix) A, of rank 2.
|
|
||||||
b: dense Matrix (or SparseTensor) B, with the same dtype as sp_a.
|
|
||||||
adjoint_a: Use the adjoint of A in the matrix multiply. If A is complex,
|
|
||||||
this is transpose(conj(A)). Otherwise it's transpose(A).
|
|
||||||
adjoint_b: Use the adjoint of B in the matrix multiply. If B is complex,
|
|
||||||
this is transpose(conj(B)). Otherwise it's transpose(B).
|
|
||||||
name: A name prefix for the returned tensors (optional)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dense matrix (pseudo-code in dense np.matrix notation):
|
|
||||||
`A = A.H if adjoint_a else A`
|
|
||||||
`B = B.H if adjoint_b else B`
|
|
||||||
`return A*B`
|
|
||||||
"""
|
"""
|
||||||
# pylint: enable=line-too-long
|
# pylint: enable=line-too-long
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user