Merge pull request #34624 from noble-ai:dense_sparse_matmul
PiperOrigin-RevId: 290179646 Change-Id: I1bd945d88215e3900005c3c70b4525ad3707cf91
This commit is contained in:
commit
f0cf1dc02a
def sparse_tensor_dense_matmul(sp_a,
                               b,
                               adjoint_a=False,
                               adjoint_b=False,
                               name=None):
  # pylint: disable=line-too-long
  """Multiply SparseTensor (or dense Matrix) (of rank 2) "A" by dense matrix
  (or SparseTensor) "B". Please note that one and only one of the inputs MUST
  be a SparseTensor and the other MUST be a dense matrix.

  No validity checking is performed on the indices of `A`. However, the
  following input format is recommended for optimal behavior:
  if `adjoint_a == False`, `A` should be sorted in lexicographically increasing
  order of its indices; if `adjoint_a == True`, `A` should be sorted in order of
  increasing dimension 1 (i.e., "column major" order).

  Args:
    sp_a: SparseTensor (or dense Matrix) A, of rank 2.
    b: dense Matrix (or SparseTensor) B, with the same dtype as sp_a.
    adjoint_a: Use the adjoint of A in the matrix multiply. If A is complex,
      this is transpose(conj(A)). Otherwise it's transpose(A).
    adjoint_b: Use the adjoint of B in the matrix multiply. If B is complex,
      this is transpose(conj(B)). Otherwise it's transpose(B).
    name: A name prefix for the returned tensors (optional).

  Returns:
    A dense matrix (pseudo-code in dense np.matrix notation):
      `A = A.H if adjoint_a else A`
      `B = B.H if adjoint_b else B`
      `return A*B`
  """
  # pylint: enable=line-too-long
  if isinstance(b, sparse_tensor.SparseTensor) \
      or isinstance(b, sparse_tensor.SparseTensorValue):
    # We can do C * D where C is sparse but if we want to do A * B when
    # B is sparse we have to transpose. But AB = (B'A')' so we have to feed in
    # the transpose of the arguments as well.
    if adjoint_a != adjoint_b:
      return array_ops.transpose(
          sparse_tensor_dense_matmul(b, sp_a, adjoint_a, adjoint_b))
    else:
      return array_ops.transpose(
          sparse_tensor_dense_matmul(
              b, sp_a, adjoint_a=not adjoint_a, adjoint_b=not adjoint_b))

  else:
    # The standard case: A is sparse, B is dense.
    sp_a = _convert_to_sparse_tensor(sp_a)
    with ops.name_scope(name, "SparseTensorDenseMatMul",
                        [sp_a.indices, sp_a.values, b]) as name:
      b = ops.convert_to_tensor(b, name="b")
      return gen_sparse_ops.sparse_tensor_dense_mat_mul(
          a_indices=sp_a.indices,
          a_values=sp_a.values,
          a_shape=sp_a.dense_shape,
          b=b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
|
||||
|
||||
|
||||
@tf_export("sparse.softmax", v1=["sparse.softmax", "sparse_softmax"])
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
# Need array_grad to register gradient for Identity.
|
||||
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
|
||||
from tensorflow.python.ops import math_ops
|
||||
# Need sparse_grad to register gradient for SparseToDense.
|
||||
@ -143,6 +144,42 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
result_dense = self.evaluate(dense)
|
||||
self.assertAllEqual(expected_dense, result_dense)
|
||||
|
||||
def testDenseSparseTensorMatMul(self):
|
||||
|
||||
np.random.seed(42)
|
||||
dense_numpy_array = np.random.rand(3, 3)
|
||||
independent_dense_tf = constant_op.constant(
|
||||
dense_numpy_array, dtype='float32')
|
||||
|
||||
sp = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[4., 8.], dense_shape=[3, 3])
|
||||
dense_of_sparse = sparse_ops.sparse_to_dense(sp.indices, sp.shape,
|
||||
sp.values)
|
||||
|
||||
result = sparse_ops.sparse_tensor_dense_matmul(
|
||||
independent_dense_tf, sp, adjoint_a=False, adjoint_b=False)
|
||||
expected = math_ops.matmul(independent_dense_tf, dense_of_sparse)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
result = sparse_ops.sparse_tensor_dense_matmul(
|
||||
independent_dense_tf, sp, adjoint_a=False, adjoint_b=True)
|
||||
expected = math_ops.matmul(independent_dense_tf,
|
||||
array_ops.transpose(dense_of_sparse))
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
result = sparse_ops.sparse_tensor_dense_matmul(
|
||||
independent_dense_tf, sp, adjoint_a=True, adjoint_b=False)
|
||||
expected = math_ops.matmul(
|
||||
array_ops.transpose(independent_dense_tf), dense_of_sparse)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
result = sparse_ops.sparse_tensor_dense_matmul(
|
||||
independent_dense_tf, sp, adjoint_a=True, adjoint_b=True)
|
||||
expected = math_ops.matmul(
|
||||
array_ops.transpose(independent_dense_tf),
|
||||
array_ops.transpose(dense_of_sparse))
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
|
||||
# Standard test-module entry point: run all tests when executed directly.
if __name__ == '__main__':
  googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user