Add CSRSparseMatrix ops.
PiperOrigin-RevId: 267029468
This commit is contained in: parent cb49525a25 · commit 57395f70db
@@ -1170,6 +1170,7 @@ tf_gen_op_libs(
         "set_ops",
         "script_ops",
         "sendrecv_ops",
+        "sparse_csr_matrix_ops",
         "sparse_ops",
         "spectral_ops",
         "state_ops",
@@ -1395,6 +1396,7 @@ cc_library(
         ":sdca_ops_op_lib",
         ":sendrecv_ops_op_lib",
         ":set_ops_op_lib",
+        ":sparse_csr_matrix_ops_op_lib",
         ":sparse_ops_op_lib",
         ":summary_ops_op_lib",
         ":spectral_ops_op_lib",
@@ -1584,6 +1586,7 @@ cc_library(
         "//tensorflow/core/kernels:summary_kernels",
         "//tensorflow/core/kernels:training_ops",
         "//tensorflow/core/kernels:word2vec_kernels",
+        "//tensorflow/core/kernels/sparse:kernels",
     ] + tf_additional_cloud_kernel_deps() + if_not_windows([
         "//tensorflow/core/kernels:fact_op",
         "//tensorflow/core/kernels:array_not_windows",
@@ -5179,6 +5182,7 @@ tf_cc_tests(
         "ops/rnn_ops_test.cc",
         "ops/set_ops_test.cc",
         "ops/shape_function_test.cc",
+        "ops/sparse_csr_matrix_ops_test.cc",
         "ops/sparse_ops_test.cc",
         "ops/spectral_ops_test.cc",
         "ops/state_ops_test.cc",
@@ -0,0 +1,29 @@
op {
  graph_op_name: "CSRSparseMatrixComponents"
  visibility: HIDDEN
  in_arg {
    name: "csr_sparse_matrix"
    description: "A batched CSRSparseMatrix."
  }
  in_arg {
    name: "index"
    description: "The index in `csr_sparse_matrix`'s batch."
  }
  out_arg {
    name: "row_ptrs"
    description: "An array containing CSR matrix row pointers."
  }
  out_arg {
    name: "col_inds"
    description: "An array containing CSR matrix column indices."
  }
  out_arg {
    name: "values"
    description: "An array containing CSR matrix nonzero values."
  }
  summary: "Reads out the CSR components at batch `index`."
  description: <<END
This op is meant only for debugging / testing, and its interface is not expected
to be stable.
END
}
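As a reference for the three outputs above, a small SciPy sketch (illustrative only, not part of the diff) of the row_ptrs / col_inds / values layout for the 4x4 example matrix used later in the SparseMatrixOrderingAMD and SparseMatrixSparseCholesky docs:

```python
# Illustrative only (SciPy, not TensorFlow): the CSR components that
# CSRSparseMatrixComponents reads out, for the 4x4 matrix
# [[1, 0, 0, 0], [0, 2, 0, 0], [0, 1, 3, 0], [0, 0, 0, 4]].
import numpy as np
from scipy.sparse import csr_matrix

a_indices = np.array([[0, 0], [1, 1], [2, 1], [2, 2], [3, 3]])
a_values = np.array([1.0, 2.0, 1.0, 3.0, 4.0], np.float32)
a = csr_matrix((a_values, (a_indices[:, 0], a_indices[:, 1])), shape=(4, 4))

print(a.indptr)   # row_ptrs: [0 1 2 4 5]
print(a.indices)  # col_inds: [0 1 1 2 3]
print(a.data)     # values:   [1. 2. 1. 3. 4.]
```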
@@ -0,0 +1,13 @@
op {
  graph_op_name: "CSRSparseMatrixToDense"
  visibility: HIDDEN
  in_arg {
    name: "sparse_input"
    description: "A batched CSRSparseMatrix."
  }
  out_arg {
    name: "dense_output"
    description: "A dense tensor."
  }
  summary: "Convert a (possibly batched) CSRSparseMatrix to dense."
}
@@ -0,0 +1,20 @@
op {
  graph_op_name: "CSRSparseMatrixToSparseTensor"
  in_arg {
    name: "sparse_matrix"
    description: "A (possibly batched) CSRSparseMatrix."
  }
  out_arg {
    name: "indices"
    description: "SparseTensor indices."
  }
  out_arg {
    name: "values"
    description: "SparseTensor values."
  }
  out_arg {
    name: "dense_shape"
    description: "SparseTensor dense shape."
  }
  summary: "Converts a (possibly batched) CSRSparseMatrix to a SparseTensor."
}
@@ -0,0 +1,16 @@
op {
  graph_op_name: "DenseToCSRSparseMatrix"
  in_arg {
    name: "dense_input"
    description: "A Dense tensor."
  }
  in_arg {
    name: "indices"
    description: "Indices of nonzero elements."
  }
  out_arg {
    name: "sparse_output"
    description: "A (possibly batched) CSRSparseMatrix."
  }
  summary: "Converts a dense tensor to a (possibly batched) CSRSparseMatrix."
}
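A hedged round-trip sketch of the dense conversions; `csr_sparse_matrix_to_dense` appears in the usage examples elsewhere in this change, while `dense_to_csr_sparse_matrix` is assumed here to follow the same snake_case wrapper naming:

```python
# Hypothetical round trip; `dense_to_csr_sparse_matrix` is an assumed wrapper
# name, and the op's `indices` input is fed with the nonzero coordinates.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a = np.array([[1.0, 0.0], [0.0, 2.0]], np.float32)
indices = tf.where(tf.not_equal(a, 0.0))  # coordinates of nonzero elements

with tf.Session() as sess:
  a_sm = sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(a, indices)
  a_rt = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(a_sm, tf.float32)
  print(sess.run(a_rt))  # recovers the original dense matrix
```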
@@ -0,0 +1,28 @@
op {
  graph_op_name: "SparseMatrixAdd"
  in_arg {
    name: "a"
    description: "A CSRSparseMatrix."
  }
  in_arg {
    name: "b"
    description: "A CSRSparseMatrix."
  }
  in_arg {
    name: "alpha"
    description: "A constant scalar."
  }
  in_arg {
    name: "beta"
    description: "A constant scalar."
  }
  out_arg {
    name: "c"
    description: "A CSRSparseMatrix."
  }
  summary: "Sparse addition of two CSR matrices, C = alpha * A + beta * B."
  description: <<END
The gradients of SparseMatrixAdd outputs with respect to alpha and beta are not
currently defined (TensorFlow will return zeros for these entries).
END
}
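A minimal sketch of `C = alpha * A + beta * B`; the Python wrapper name `sparse_matrix_add` and its `alpha`/`beta` keyword arguments are assumptions based on the naming convention of the other wrappers in `sparse_csr_matrix_ops`:

```python
# Assumed wrapper `sparse_matrix_add(a, b, alpha, beta)`; computes
# C = 2*A + 3*A = 5*A, keeping A's sparsity structure.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a_st = tf.SparseTensor([[0, 0], [1, 1]],
                       np.array([1.0, 2.0], np.float32), [2, 2])

with tf.Session() as sess:
  a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      a_st.indices, a_st.values, a_st.dense_shape)
  c_sm = sparse_csr_matrix_ops.sparse_matrix_add(a_sm, a_sm,
                                                 alpha=2.0, beta=3.0)
  c = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(c_sm, tf.float32)
  print(sess.run(c))  # [[5. 0.] [0. 10.]]
```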
@@ -0,0 +1,67 @@
op {
  graph_op_name: "SparseMatrixMatMul"
  in_arg {
    name: "a"
    description: "A CSRSparseMatrix."
  }
  in_arg {
    name: "b"
    description: "A dense tensor."
  }
  out_arg {
    name: "output"
    description: "A dense output tensor."
  }
  attr {
    name: "transpose_a"
    description: "Indicates whether `a` should be transposed."
  }
  attr {
    name: "transpose_b"
    description: "Indicates whether `b` should be transposed."
  }
  attr {
    name: "adjoint_a"
    description: "Indicates whether `a` should be conjugate-transposed."
  }
  attr {
    name: "adjoint_b"
    description: "Indicates whether `b` should be conjugate-transposed."
  }
  attr {
    name: "transpose_output"
    description: "Transposes the product of `a` and `b`."
  }
  attr {
    name: "conjugate_output"
    description: "Conjugates the product of `a` and `b`."
  }
  summary: "Matrix-multiplies a sparse matrix with a dense matrix."
  description: <<END
Returns a dense matrix.
For inputs A and B, where A is CSR and B is dense, this op returns a dense C.

If transpose_output is `false`, returns:
```
  C = A . B
```

If transpose_output is `true`, returns:
```
  C = transpose(A . B) = transpose(B) . transpose(A)
```
where the transposition is performed along the two innermost (matrix)
dimensions.

If conjugate_output is `true`, returns:
```
  C = conjugate(A . B) = conjugate(A) . conjugate(B)
```

If both conjugate_output and transpose_output are `true`, returns:
```
  C = conjugate(transpose(A . B)) = conjugate(transpose(B)) .
                                    conjugate(transpose(A))
```
END
}
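A minimal sketch of the sparse-dense product described above; the wrapper name `sparse_matrix_mat_mul` is an assumption derived from the op name, and the default attrs (no transpose/adjoint/conjugate) are used:

```python
# Assumed wrapper `sparse_matrix_mat_mul(a, b)` with default attrs; A is the
# sparse diag(1, 2), so C = A . B scales the rows of B.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a_st = tf.SparseTensor([[0, 0], [1, 1]],
                       np.array([1.0, 2.0], np.float32), [2, 2])
b = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)

with tf.Session() as sess:
  a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      a_st.indices, a_st.values, a_st.dense_shape)
  c = sparse_csr_matrix_ops.sparse_matrix_mat_mul(a_sm, b)
  print(sess.run(c))  # [[1. 2.] [6. 8.]]
```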
@@ -0,0 +1,26 @@
op {
  graph_op_name: "SparseMatrixMul"
  in_arg {
    name: "a"
    description: "A CSRSparseMatrix."
  }
  in_arg {
    name: "b"
    description: "A dense tensor."
  }
  out_arg {
    name: "output"
    description: "A dense output tensor."
  }
  summary: "Element-wise multiplication of a sparse matrix with a dense tensor."
  description: <<END
Returns a sparse matrix.

The dense tensor `b` must either be a scalar, or `a` must be a rank-3
`SparseMatrix` and `b` must be shaped `[batch_size, 1, 1]`, in which case the
multiply operation broadcasts over the batch dimension.

**NOTE** Even if `b` is zero, the sparsity structure of the output does not
change.
END
}
@@ -0,0 +1,12 @@
op {
  graph_op_name: "SparseMatrixNNZ"
  in_arg {
    name: "sparse_matrix"
    description: "A CSRSparseMatrix."
  }
  out_arg {
    name: "nnz"
    description: "The number of nonzeroes of `sparse_matrix`."
  }
  summary: "Returns the number of nonzeroes of `sparse_matrix`."
}
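For intuition, an assumption-free SciPy illustration of what the op reports per batch component: the number of stored nonzeros, which equals the last entry of the CSR row-pointer array:

```python
# Illustrative only (SciPy): nnz of a CSR component equals row_ptrs[-1].
import numpy as np
from scipy.sparse import csr_matrix

a = csr_matrix(np.array([[1.0, 0.0], [0.0, 2.0]], np.float32))
print(a.nnz, a.indptr[-1])  # 2 2
```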
@@ -0,0 +1,63 @@
op {
  graph_op_name: "SparseMatrixOrderingAMD"
  in_arg {
    name: "input"
    description: "A `CSRSparseMatrix`."
  }
  out_arg {
    name: "output"
    description: "The Approximate Minimum Degree (AMD) ordering of `input`."
  }
  summary: "Computes the Approximate Minimum Degree (AMD) ordering of `input`."
  description: <<END
Computes the Approximate Minimum Degree (AMD) ordering for a sparse matrix.

The returned permutation may be used to permute the rows and columns of the
given sparse matrix. This typically results in a sparse Cholesky (or other)
decomposition of the permuted matrix with less zero fill-in than the
decomposition of the original matrix.

The input sparse matrix may have rank 2 or rank 3. The output Tensor
representing the ordering then has rank 1 or 2 respectively, with the same
batch shape as the input.

Each component of the input sparse matrix must represent a square symmetric
matrix; only the lower triangular part of the matrix is read. The values of the
sparse matrix do not affect the returned permutation; only the sparsity
pattern of the sparse matrix is used. Hence, a single AMD ordering may be
reused for the Cholesky decompositions of sparse matrices with the same sparsity
pattern but with possibly different values.

Each batch component of the output permutation represents a permutation of `N`
elements, where the input sparse matrix components each have `N` rows. That is,
the component contains each of the integers `{0, .. N-1}` exactly once. The
`i`th element represents the row index that the `i`th row maps to.

Usage example:

```python
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a_indices = np.array([[0, 0], [1, 1], [2, 1], [2, 2], [3, 3]])
a_values = np.array([1.0, 2.0, 1.0, 3.0, 4.0], np.float32)
a_dense_shape = [4, 4]

with tf.Session() as sess:
  # Define (COO format) SparseTensor over Numpy array.
  a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape)

  # Convert SparseTensors to CSR SparseMatrix.
  a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      a_st.indices, a_st.values, a_st.dense_shape)

  # Obtain the AMD Ordering for the CSR SparseMatrix.
  ordering_amd = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(a_sm)

  ordering_amd_value = sess.run(ordering_amd)
```

`ordering_amd_value` stores the AMD ordering: `[1 2 3 0]`.

input: A `CSRSparseMatrix`.
END
}
@@ -0,0 +1,19 @@
op {
  graph_op_name: "SparseMatrixSoftmax"
  in_arg {
    name: "logits"
    description: "A CSRSparseMatrix."
  }
  out_arg {
    name: "softmax"
    description: "A CSRSparseMatrix."
  }
  summary: "Calculates the softmax of a CSRSparseMatrix."
  description: <<END
Calculate the softmax of the innermost dimensions of a SparseMatrix.

Missing values are treated as `-inf` (i.e., logits of zero probability); and
the output has the same sparsity structure as the input (though missing values
in the output may now be treated as having probability zero).
END
}
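A NumPy-only worked example of the semantics above: missing entries act like `-inf` logits, so the softmax of one sparse row is just the softmax over its stored values, and absent entries keep probability zero:

```python
# NumPy-only illustration of SparseMatrixSoftmax semantics for one CSR row.
import numpy as np

row_values = np.array([1.0, 2.0], np.float32)  # stored logits of one row
exp = np.exp(row_values - row_values.max())    # numerically stable softmax
print(exp / exp.sum())                         # [0.26894143 0.7310586 ]
# Entries absent from the sparsity pattern stay absent in the output,
# i.e. they are treated as having probability zero.
```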
@@ -0,0 +1,16 @@
op {
  graph_op_name: "SparseMatrixSoftmaxGrad"
  in_arg {
    name: "softmax"
    description: "A CSRSparseMatrix."
  }
  in_arg {
    name: "grad_softmax"
    description: "The gradient of `softmax`."
  }
  out_arg {
    name: "gradient"
    description: "The output gradient."
  }
  summary: "Calculates the gradient of the SparseMatrixSoftmax op."
}
@@ -0,0 +1,95 @@
op {
  graph_op_name: "SparseMatrixSparseCholesky"
  in_arg {
    name: "input"
    description: "A `CSRSparseMatrix`."
  }
  in_arg {
    name: "permutation"
    description: "A fill-in reducing permutation matrix."
  }
  out_arg {
    name: "output"
    description: "The sparse Cholesky decomposition of `input`."
  }
  summary: "Computes the sparse Cholesky decomposition of `input`."
  description: <<END
Computes the Sparse Cholesky decomposition of a sparse matrix, with the given
fill-in reducing permutation.

The input sparse matrix and the fill-in reducing permutation `permutation` must
have compatible shapes. If the sparse matrix has rank 3, with the batch
dimension `B`, then the `permutation` must be of rank 2, with the same batch
dimension `B`. There is no support for broadcasting.

Furthermore, each component vector of `permutation` must be of length `N`,
containing each of the integers {0, 1, ..., N - 1} exactly once, where `N` is
the number of rows of each component of the sparse matrix.

Each component of the input sparse matrix must represent a symmetric positive
definite (SPD) matrix; only the lower triangular part of the matrix is read.
If any individual component is not SPD, then an InvalidArgument error is
thrown.

The returned sparse matrix has the same dense shape as the input sparse matrix.
For each component `A` of the input sparse matrix, the corresponding output
sparse matrix represents `L`, the lower triangular Cholesky factor satisfying
the following identity:

```
  A = L * Lt
```

where Lt denotes the transpose of L (or its conjugate transpose, if `type` is
`complex64` or `complex128`).

The `type` parameter denotes the type of the matrix elements. The supported
types are: `float32`, `float64`, `complex64` and `complex128`.

Usage example:

```python
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a_indices = np.array([[0, 0], [1, 1], [2, 1], [2, 2], [3, 3]])
a_values = np.array([1.0, 2.0, 1.0, 3.0, 4.0], np.float32)
a_dense_shape = [4, 4]

with tf.Session() as sess:
  # Define (COO format) SparseTensor over Numpy array.
  a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape)

  # Convert SparseTensors to CSR SparseMatrix.
  a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      a_st.indices, a_st.values, a_st.dense_shape)

  # Obtain the Sparse Cholesky factor using AMD Ordering for reducing zero
  # fill-in (number of structural non-zeros in the sparse Cholesky factor).
  ordering_amd = sparse_csr_matrix_ops.sparse_matrix_ordering_amd(a_sm)
  cholesky_sparse_matrices = (
      sparse_csr_matrix_ops.sparse_matrix_sparse_cholesky(
          a_sm, ordering_amd, type=tf.float32))

  # Convert the CSRSparseMatrix Cholesky factor to a dense Tensor
  dense_cholesky = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
      cholesky_sparse_matrices, tf.float32)

  # Evaluate the dense Tensor value.
  dense_cholesky_value = sess.run(dense_cholesky)
```

`dense_cholesky_value` stores the dense Cholesky factor:

```
[[ 1.    0.    0.    0.  ]
 [ 0.    1.41  0.    0.  ]
 [ 0.    0.70  1.58  0.  ]
 [ 0.    0.    0.    2.  ]]
```

input: A `CSRSparseMatrix`.
permutation: A `Tensor`.
type: The type of `input`.
END
}
@@ -0,0 +1,110 @@
op {
  graph_op_name: "SparseMatrixSparseMatMul"
  in_arg {
    name: "a"
    description: "A CSRSparseMatrix."
  }
  in_arg {
    name: "b"
    description: "A CSRSparseMatrix."
  }
  out_arg {
    name: "c"
    description: "A CSRSparseMatrix."
  }
  attr {
    name: "transpose_a"
    description: "Indicates whether `a` should be transposed."
  }
  attr {
    name: "transpose_b"
    description: "Indicates whether `b` should be transposed."
  }
  attr {
    name: "adjoint_a"
    description: "Indicates whether `a` should be conjugate-transposed."
  }
  attr {
    name: "adjoint_b"
    description: "Indicates whether `b` should be conjugate-transposed."
  }
  summary: "Sparse-matrix-multiplies two CSR matrices `a` and `b`."
  description: <<END
Performs a matrix multiplication of a sparse matrix `a` with a sparse matrix
`b`; returns a sparse matrix `a * b`, unless either `a` or `b` is transposed or
adjointed.

Each matrix may be transposed or adjointed (conjugated and transposed)
according to the Boolean parameters `transpose_a`, `adjoint_a`, `transpose_b`
and `adjoint_b`. At most one of `transpose_a` or `adjoint_a` may be True.
Similarly, at most one of `transpose_b` or `adjoint_b` may be True.

The inputs must have compatible shapes. That is, the inner dimension of `a`
must be equal to the outer dimension of `b`. This requirement is adjusted
according to whether either `a` or `b` is transposed or adjointed.

The `type` parameter denotes the type of the matrix elements. Both `a` and `b`
must have the same type. The supported types are: `float32`, `float64`,
`complex64` and `complex128`.

Both `a` and `b` must have the same rank. Broadcasting is not supported. If they
have rank 3, each batch of 2D CSRSparseMatrices within `a` and `b` must have the
same dense shape.

The sparse matrix product may have numeric (non-structural) zeros.
TODO(anudhyan): Consider adding a boolean attribute to control whether to prune
zeros.

Usage example:

```python
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a_indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
a_values = np.array([1.0, 5.0, -1.0, -2.0], np.float32)
a_dense_shape = [4, 5]

b_indices = np.array([[0, 0], [3, 0], [3, 1]])
b_values = np.array([2.0, 7.0, 8.0], np.float32)
b_dense_shape = [5, 3]

with tf.Session() as sess:
  # Define (COO format) Sparse Tensors over Numpy arrays
  a_st = tf.SparseTensor(a_indices, a_values, a_dense_shape)
  b_st = tf.SparseTensor(b_indices, b_values, b_dense_shape)

  # Convert SparseTensors to CSR SparseMatrix
  a_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      a_st.indices, a_st.values, a_st.dense_shape)
  b_sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      b_st.indices, b_st.values, b_st.dense_shape)

  # Compute the CSR SparseMatrix matrix multiplication
  c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
      a=a_sm, b=b_sm, type=tf.float32)

  # Convert the CSR SparseMatrix product to a dense Tensor
  c_sm_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
      c_sm, tf.float32)
  # Evaluate the dense Tensor value
  c_sm_dense_value = sess.run(c_sm_dense)
```

`c_sm_dense_value` stores the dense matrix product:

```
[[  2.   0.   0.]
 [  0.   0.   0.]
 [ 35.  40.   0.]
 [ -4.   0.   0.]]
```

a: A `CSRSparseMatrix`.
b: A `CSRSparseMatrix` with the same type and rank as `a`.
type: The type of both `a` and `b`.
transpose_a: If True, `a` is transposed before multiplication.
transpose_b: If True, `b` is transposed before multiplication.
adjoint_a: If True, `a` is adjointed before multiplication.
adjoint_b: If True, `b` is adjointed before multiplication.
END
}
@@ -0,0 +1,20 @@
op {
  graph_op_name: "SparseMatrixTranspose"
  in_arg {
    name: "input"
    description: "A CSRSparseMatrix."
  }
  out_arg {
    name: "output"
    description: "A CSRSparseMatrix."
  }
  attr {
    name: "conjugate"
    description: "Indicates whether `input` should be conjugated."
  }
  summary: "Transposes the inner (matrix) dimensions of a CSRSparseMatrix."
  description: <<END
Transposes the inner (matrix) dimensions of a SparseMatrix and optionally
conjugates its values.
END
}
@@ -0,0 +1,12 @@
op {
  graph_op_name: "SparseMatrixZeros"
  in_arg {
    name: "dense_shape"
    description: "The desired matrix shape."
  }
  out_arg {
    name: "sparse_matrix"
    description: "An empty CSR matrix with shape `dense_shape`."
  }
  summary: "Creates an all-zeros CSRSparseMatrix with shape `dense_shape`."
}
@@ -0,0 +1,20 @@
op {
  graph_op_name: "SparseTensorToCSRSparseMatrix"
  in_arg {
    name: "indices"
    description: "SparseTensor indices."
  }
  in_arg {
    name: "values"
    description: "SparseTensor values."
  }
  in_arg {
    name: "dense_shape"
    description: "SparseTensor dense shape."
  }
  out_arg {
    name: "sparse_matrix"
    description: "A (possibly batched) CSRSparseMatrix."
  }
  summary: "Converts a SparseTensor to a (possibly batched) CSRSparseMatrix."
}
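A hedged sketch of the SparseTensor round trip; `sparse_tensor_to_csr_sparse_matrix` appears in the usage examples in this change, while `csr_sparse_matrix_to_sparse_tensor` and its `type` argument are assumptions derived from the op definitions above:

```python
# `csr_sparse_matrix_to_sparse_tensor` is an assumed wrapper name; round-trips
# a SparseTensor through the CSRSparseMatrix representation.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

st = tf.SparseTensor([[0, 0], [1, 1]],
                     np.array([1.0, 2.0], np.float32), [2, 2])

with tf.Session() as sess:
  sm = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
      st.indices, st.values, st.dense_shape)
  indices, values, dense_shape = (
      sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(
          sm, type=tf.float32))
  print(sess.run([indices, values, dense_shape]))
```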
@@ -3359,8 +3359,9 @@ tf_kernel_library(
     deps = [
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core/kernels:cuda_solvers",
         "//tensorflow/stream_executor/cuda:cusparse_lib",
-    ],
+    ] + if_cuda(["@cub_archive//:cub"]),
 )
 
 LINALG_DEPS = [
@@ -1,21 +1,22 @@
 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 
     http://www.apache.org/licenses/LICENSE-2.0
 
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-=============================================================================
-*/
+==============================================================================*/
 
 #ifdef GOOGLE_CUDA
 
+#include "tensorflow/core/kernels/cuda_sparse.h"
+
 #include <complex>
 #include <memory>
 #include <unordered_map>
@@ -26,7 +27,7 @@
 #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/cuda_sparse.h"
+#include "tensorflow/core/kernels/cuda_solvers.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -37,6 +38,8 @@
 #include "tensorflow/core/platform/stream_executor.h"
 #include "tensorflow/core/platform/types.h"
 
+// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.
+
 namespace tensorflow {
 namespace {
 
@@ -126,6 +129,13 @@ class CudaSparseHandles {
   TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseHandles);
 };
 
+// TODO(ebrevdo): Replace global mutex guarding CudaSparseHandles
+// lookup with one of:
+//   1. Adding the handle to the CudaStream structure; do the lookup there.
+//   2. Add a thread-local cusparse, set it to the current stream
+//      upon each call.
+// #1 seems like the cleanest option but will need to wait until this
+// is moved into TF core.
 static mutex handle_map_mutex(LINKER_INITIALIZED);
 
 using HandleMap = std::unordered_map<cudaStream_t, CudaSparseHandles>;
@@ -141,11 +151,13 @@ HandleMap* GetHandleMapSingleton() {
 
 CudaSparse::CudaSparse(OpKernelContext* context)
     : initialized_(false), context_(context) {
-  cuda_stream_ =
-      *reinterpret_cast<const cudaStream_t*>(context->op_device_context()
-                                                 ->stream()
-                                                 ->implementation()
-                                                 ->GpuStreamMemberHack());
+  auto cuda_stream_ptr =
+      reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+                                                ->stream()
+                                                ->implementation()
+                                                ->GpuStreamMemberHack());
+  DCHECK(cuda_stream_ptr);
+  cuda_stream_ = *cuda_stream_ptr;
 }
 
 Status CudaSparse::Initialize() {
@@ -376,6 +388,303 @@ static inline Status Gtsv2StridedBatchBufferSizeImpl(

TF_CALL_LAPACK_TYPES(GTSV2_STRIDED_BATCH_BUFFER_SIZE_INSTANCE);

Status CudaSparse::Coo2csr(const int* cooRowInd, int nnz, int m,
                           int* csrRowPtr) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcoo2csr(cusparseHandle_t handle,
  //                                               const int *cooRowInd,
  //                                               int nnz,
  //                                               int m,
  //                                               int *csrSortedRowPtr,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcoo2csr(*cusparse_handle_, cooRowInd,
                                               nnz, m, csrRowPtr,
                                               CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

Status CudaSparse::Csr2coo(const int* csrRowPtr, int nnz, int m,
                           int* cooRowInd) const {
  // cusparseStatus_t CUSPARSEAPI cusparseXcsr2coo(cusparseHandle_t handle,
  //                                               const int *csrRowPtr,
  //                                               int nnz,
  //                                               int m,
  //                                               int *cooRowInd,
  //                                               cusparseIndexBase_t
  //                                               idxBase);
  DCHECK(initialized_);
  TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsr2coo(*cusparse_handle_, csrRowPtr,
                                               nnz, m, cooRowInd,
                                               CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

Status CudaSparse::CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA,
                              int nnzA, const int* csrSortedRowPtrA,
                              const int* csrSortedColIndA,
                              const cusparseMatDescr_t descrB, int nnzB,
                              const int* csrSortedRowPtrB,
                              const int* csrSortedColIndB,
                              const cusparseMatDescr_t descrC,
                              int* csrSortedRowPtrC, int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgeamNnz(
      *cusparse_handle_, m, n, descrA, nnzA, csrSortedRowPtrA, csrSortedColIndA,
      descrB, nnzB, csrSortedRowPtrB, csrSortedColIndB, descrC,
      csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrmmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int n, int k,
    int nnz, const Scalar* alpha_host, const cusparseMatDescr_t descrA,
    const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* B, int ldb,
    const Scalar* beta_host, Scalar* C, int ldc) {
  // cusparseStatus_t CUSPARSEAPI cusparseScsrmm2(
  //     cusparseHandle_t handle, cusparseOperation_t transA,
  //     cusparseOperation_t transB, int m, int n, int k, int nnz,
  //     const float* alpha, const cusparseMatDescr_t descrA,
  //     const float* csrSortedValA, const int* csrSortedRowPtrA,
  //     const int* csrSortedColIndA, const float* B, int ldb, const float*
  //     beta, float* C, int ldc);
  TF_RETURN_IF_CUSPARSE_ERROR(op(
      cusparse_handle, transA, transB, m, n, k, nnz, AsCudaComplex(alpha_host),
      descrA, AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
      AsCudaComplex(B), ldb, AsCudaComplex(beta_host), AsCudaComplex(C), ldc));
  return Status::OK();
}

#define CSRMM_INSTANCE(Scalar, sparse_prefix)                                 \
  template <>                                                                 \
  Status CudaSparse::Csrmm<Scalar>(                                           \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int n,   \
      int k, int nnz, const Scalar* alpha_host,                               \
      const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,           \
      const int* csrSortedRowPtrA, const int* csrSortedColIndA,               \
      const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, int ldc)  \
      const {                                                                 \
    DCHECK(initialized_);                                                     \
    return CsrmmImpl(SPARSE_FN(csrmm2, sparse_prefix), context_,              \
                     *cusparse_handle_, transA, transB, m, n, k, nnz,         \
                     alpha_host, descrA, csrSortedValA, csrSortedRowPtrA,     \
                     csrSortedColIndA, B, ldb, beta_host, C, ldc);            \
  }

TF_CALL_LAPACK_TYPES(CSRMM_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status CsrmvImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, int m, int n, int nnz, const Scalar* alpha_host,
    const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA, const Scalar* x,
    const Scalar* beta_host, Scalar* y) {
  TF_RETURN_IF_CUSPARSE_ERROR(
      op(cusparse_handle, transA, m, n, nnz, AsCudaComplex(alpha_host), descrA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(x), AsCudaComplex(beta_host), AsCudaComplex(y)));
  return Status::OK();
}

// TODO(ebrevdo,rmlarsen): Use csrmv_mp for all cases when available in CUDA 9.
#define CSRMV_INSTANCE(Scalar, sparse_prefix)                                 \
  template <>                                                                 \
  Status CudaSparse::Csrmv<Scalar>(                                           \
      cusparseOperation_t transA, int m, int n, int nnz,                      \
      const Scalar* alpha_host, const cusparseMatDescr_t descrA,              \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
      const int* csrSortedColIndA, const Scalar* x, const Scalar* beta_host,  \
      Scalar* y) const {                                                      \
    DCHECK(initialized_);                                                     \
    if (transA == CUSPARSE_OPERATION_NON_TRANSPOSE) {                         \
      return CsrmvImpl(SPARSE_FN(csrmv_mp, sparse_prefix), context_,          \
                       *cusparse_handle_, transA, m, n, nnz, alpha_host,      \
                       descrA, csrSortedValA, csrSortedRowPtrA,               \
                       csrSortedColIndA, x, beta_host, y);                    \
    } else {                                                                  \
      return CsrmvImpl(SPARSE_FN(csrmv, sparse_prefix), context_,             \
                       *cusparse_handle_, transA, m, n, nnz, alpha_host,      \
                       descrA, csrSortedValA, csrSortedRowPtrA,               \
                       csrSortedColIndA, x, beta_host, y);                    \
    }                                                                         \
  }

TF_CALL_LAPACK_TYPES(CSRMV_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status CsrgeamImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,
    int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const Scalar* beta,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_CUSPARSE_ERROR(
      op(cusparse_handle, m, n, AsCudaComplex(alpha), descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         AsCudaComplex(beta), descrB, nnzB, AsCudaComplex(csrSortedValB),
         csrSortedRowPtrB, csrSortedColIndB, descrC,
         AsCudaComplex(csrSortedValC), csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEAM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status CudaSparse::Csrgeam<Scalar>(                                         \
      int m, int n, const Scalar* alpha, const cusparseMatDescr_t descrA,     \
      int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA,     \
      const int* csrSortedColIndA, const Scalar* beta,                        \
      const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, \
      const int* csrSortedRowPtrB, const int* csrSortedColIndB,               \
      const cusparseMatDescr_t descrC, Scalar* csrSortedValC,                 \
      int* csrSortedRowPtrC, int* csrSortedColIndC) {                         \
    DCHECK(initialized_);                                                     \
    return CsrgeamImpl(SPARSE_FN(csrgeam, sparse_prefix), context_,           \
                       *cusparse_handle_, m, n, alpha, descrA, nnzA,          \
                       csrSortedValA, csrSortedRowPtrA, csrSortedColIndA,     \
                       beta, descrB, nnzB, csrSortedValB, csrSortedRowPtrB,   \
                       csrSortedColIndB, descrC, csrSortedValC,               \
                       csrSortedRowPtrC, csrSortedColIndC);                   \
  }

TF_CALL_LAPACK_TYPES(CSRGEAM_INSTANCE);

Status CudaSparse::CsrgemmNnz(
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const int* csrSortedRowPtrA,
    const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
    int* nnzTotalDevHostPtr) {
  DCHECK(initialized_);
  DCHECK(nnzTotalDevHostPtr != nullptr);
  TF_RETURN_IF_CUSPARSE_ERROR(cusparseXcsrgemmNnz(
      *cusparse_handle_, transA, transB, m, k, n, descrA, nnzA,
      csrSortedRowPtrA, csrSortedColIndA, descrB, nnzB, csrSortedRowPtrB,
      csrSortedColIndB, descrC, csrSortedRowPtrC, nnzTotalDevHostPtr));
  return Status::OK();
}

template <typename Scalar, typename SparseFnT>
static inline Status CsrgemmImpl(
    SparseFnT op, OpKernelContext* context, cusparseHandle_t cusparse_handle,
    cusparseOperation_t transA, cusparseOperation_t transB, int m, int k, int n,
    const cusparseMatDescr_t descrA, int nnzA, const Scalar* csrSortedValA,
    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
    const cusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB,
    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
    const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
    int* csrSortedRowPtrC, int* csrSortedColIndC) {
  TF_RETURN_IF_CUSPARSE_ERROR(
      op(cusparse_handle, transA, transB, m, k, n, descrA, nnzA,
         AsCudaComplex(csrSortedValA), csrSortedRowPtrA, csrSortedColIndA,
         descrB, nnzB, AsCudaComplex(csrSortedValB), csrSortedRowPtrB,
         csrSortedColIndB, descrC, AsCudaComplex(csrSortedValC),
         csrSortedRowPtrC, csrSortedColIndC));
  return Status::OK();
}

#define CSRGEMM_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status CudaSparse::Csrgemm<Scalar>(                                         \
      cusparseOperation_t transA, cusparseOperation_t transB, int m, int k,   \
      int n, const cusparseMatDescr_t descrA, int nnzA,                       \
      const Scalar* csrSortedValA, const int* csrSortedRowPtrA,               \
      const int* csrSortedColIndA, const cusparseMatDescr_t descrB, int nnzB, \
      const Scalar* csrSortedValB, const int* csrSortedRowPtrB,               \
      const int* csrSortedColIndB, const cusparseMatDescr_t descrC,           \
      Scalar* csrSortedValC, int* csrSortedRowPtrC, int* csrSortedColIndC) {  \
    DCHECK(initialized_);                                                     \
    return CsrgemmImpl(SPARSE_FN(csrgemm, sparse_prefix), context_,           \
                       *cusparse_handle_, transA, transB, m, k, n, descrA,    \
                       nnzA, csrSortedValA, csrSortedRowPtrA,                 \
                       csrSortedColIndA, descrB, nnzB, csrSortedValB,         \
                       csrSortedRowPtrB, csrSortedColIndB, descrC,            \
                       csrSortedValC, csrSortedRowPtrC, csrSortedColIndC);    \
  }

TF_CALL_LAPACK_TYPES(CSRGEMM_INSTANCE);

template <typename Scalar, typename BufferSizeFnT, typename SparseFnT>
static inline Status Csru2csrImpl(SparseFnT op, BufferSizeFnT buffer_size_op,
                                  OpKernelContext* context,
                                  cusparseHandle_t cusparse_handle, int m,
                                  int n, int nnz,
                                  const cusparseMatDescr_t descrA,
                                  Scalar* csrVal, const int* csrRowPtr,
                                  int* csrColInd) {
  CudaSparseCsrSortingConversionInfo info;
  TF_RETURN_IF_ERROR(info.Initialize());

  size_t pBufferSizeInBytes = 0;

  TF_RETURN_IF_CUSPARSE_ERROR(
      buffer_size_op(cusparse_handle, m, n, nnz, AsCudaComplex(csrVal),
                     csrRowPtr, csrColInd, info.info(), &pBufferSizeInBytes));

  Tensor pBuffer_t;
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DT_INT8, TensorShape({static_cast<int64>(pBufferSizeInBytes)}),
      &pBuffer_t));
  auto pBuffer = pBuffer_t.flat<int8>();
  DCHECK(pBuffer.data() != nullptr);

  TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz, descrA,
                                 AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                 info.info(), pBuffer.data()));

  return Status::OK();
}

#define CSRU2CSR_INSTANCE(Scalar, sparse_prefix)                              \
  template <>                                                                 \
  Status CudaSparse::Csru2csr<Scalar>(                                        \
      int m, int n, int nnz, const cusparseMatDescr_t descrA, Scalar* csrVal, \
      const int* csrRowPtr, int* csrColInd) {                                 \
    DCHECK(initialized_);                                                     \
    return Csru2csrImpl(SPARSE_FN(csru2csr, sparse_prefix),                   \
                        BUFSIZE_FN(csru2csr, sparse_prefix), context_,        \
                        *cusparse_handle_, m, n, nnz, descrA, csrVal,         \
                        csrRowPtr, csrColInd);                                \
  }

TF_CALL_LAPACK_TYPES(CSRU2CSR_INSTANCE);

template <typename Scalar, typename SparseFnT>
static inline Status Csr2cscImpl(SparseFnT op, OpKernelContext* context,
                                 cusparseHandle_t cusparse_handle, int m, int n,
                                 int nnz, const Scalar* csrVal,
                                 const int* csrRowPtr, const int* csrColInd,
                                 Scalar* cscVal, int* cscRowInd, int* cscColPtr,
                                 const cusparseAction_t copyValues) {
  TF_RETURN_IF_CUSPARSE_ERROR(op(cusparse_handle, m, n, nnz,
                                 AsCudaComplex(csrVal), csrRowPtr, csrColInd,
                                 AsCudaComplex(cscVal), cscRowInd, cscColPtr,
                                 copyValues, CUSPARSE_INDEX_BASE_ZERO));
  return Status::OK();
}

#define CSR2CSC_INSTANCE(Scalar, sparse_prefix)                               \
  template <>                                                                 \
  Status CudaSparse::Csr2csc<Scalar>(                                         \
      int m, int n, int nnz, const Scalar* csrVal, const int* csrRowPtr,      \
      const int* csrColInd, Scalar* cscVal, int* cscRowInd, int* cscColPtr,   \
      const cusparseAction_t copyValues) {                                    \
    DCHECK(initialized_);                                                     \
    return Csr2cscImpl(SPARSE_FN(csr2csc, sparse_prefix), context_,           \
                       *cusparse_handle_, m, n, nnz, csrVal, csrRowPtr,       \
                       csrColInd, cscVal, cscRowInd, cscColPtr, copyValues);  \
  }

TF_CALL_LAPACK_TYPES(CSR2CSC_INSTANCE);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@@ -11,8 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
-==============================================================================
-*/
+==============================================================================*/
 
 #ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
 #define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
@@ -76,6 +75,22 @@ inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) {
   }                                                      \
 } while (0)
 
+inline cusparseOperation_t TransposeAndConjugateToCuSparseOp(bool transpose,
+                                                             bool conjugate,
+                                                             Status* status) {
+  if (transpose) {
+    return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
+                     : CUSPARSE_OPERATION_TRANSPOSE;
+  } else {
+    if (conjugate) {
+      DCHECK(status != nullptr);
+      *status = errors::InvalidArgument(
+          "Conjugate == True and transpose == False is not supported.");
+    }
+    return CUSPARSE_OPERATION_NON_TRANSPOSE;
+  }
+}
+
 // The CudaSparse class provides a simplified templated API for cuSparse
 // (http://docs.nvidia.com/cuda/cusparse/index.html).
 // An object of this class wraps static cuSparse instances,
@@ -89,7 +104,7 @@ inline string ConvertCUSparseErrorToString(const cusparseStatus_t status) {
 class CudaSparse {
  public:
   // This object stores a pointer to context, which must outlive it.
-  explicit CudaSparse(OpKernelContext *context);
+  explicit CudaSparse(OpKernelContext* context);
   virtual ~CudaSparse() {}
 
   // This initializes the CudaSparse class if it hasn't
@@ -180,6 +195,119 @@ class CudaSparse {
                                  int batchStride,
                                  size_t *bufferSizeInBytes) const;

  // Uncompresses the indices of rows or columns. It can be interpreted as a
  // conversion from CSR to COO sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo.
  Status Csr2coo(const int* csrRowPtr, int nnz, int m, int* cooRowInd) const;

  // Compresses the indices of rows or columns. It can be interpreted as a
  // conversion from COO to CSR sparse storage format. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr.
  Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const;

  // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C,
  // where A is a sparse matrix in CSR format, B and C are dense tall
  // matrices. This routine allows transposition of matrix B, which
  // may improve performance. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2
  //
  // **NOTE** Matrices B and C are expected to be in column-major
  // order; to make them consistent with TensorFlow they
  // must be transposed (or the matmul op's pre/post-processing must take this
  // into account).
  //
  // **NOTE** This is an in-place operation for data in C.
  template <typename Scalar>
  Status Csrmm(cusparseOperation_t transA, cusparseOperation_t transB, int m,
               int n, int k, int nnz, const Scalar* alpha_host,
               const cusparseMatDescr_t descrA, const Scalar* csrSortedValA,
               const int* csrSortedRowPtrA, const int* csrSortedColIndA,
               const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C,
               int ldc) const;

  // Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y,
  // where A is a sparse matrix in CSR format, x and y are dense vectors. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath
  //
  // **NOTE** This is an in-place operation for data in y.
  template <typename Scalar>
  Status Csrmv(cusparseOperation_t transA, int m, int n, int nnz,
               const Scalar* alpha_host, const cusparseMatDescr_t descrA,
               const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
               const int* csrSortedColIndA, const Scalar* x,
               const Scalar* beta_host, Scalar* y) const;

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  Status CsrgeamNnz(int m, int n, const cusparseMatDescr_t descrA, int nnzA,
                    const int* csrSortedRowPtrA, const int* csrSortedColIndA,
                    const cusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);

  // Computes sparse-sparse matrix addition of matrices
  // stored in CSR format. This is part two: perform sparse-sparse
  // addition. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam.
  template <typename Scalar>
  Status Csrgeam(int m, int n, const Scalar* alpha,
                 const cusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const Scalar* beta,
                 const cusparseMatDescr_t descrB, int nnzB,
                 const Scalar* csrSortedValB, const int* csrSortedRowPtrB,
                 const int* csrSortedColIndB, const cusparseMatDescr_t descrC,
                 Scalar* csrSortedValC, int* csrSortedRowPtrC,
                 int* csrSortedColIndC);

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part one: calculate nnz of the
  // output. csrSortedRowPtrC must be preallocated on device with
  // m + 1 entries. See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
  Status CsrgemmNnz(cusparseOperation_t transA, cusparseOperation_t transB,
                    int m, int k, int n, const cusparseMatDescr_t descrA,
                    int nnzA, const int* csrSortedRowPtrA,
                    const int* csrSortedColIndA,
                    const cusparseMatDescr_t descrB, int nnzB,
                    const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                    const cusparseMatDescr_t descrC, int* csrSortedRowPtrC,
                    int* nnzTotalDevHostPtr);

  // Computes sparse-sparse matrix multiplication of matrices
  // stored in CSR format. This is part two: perform the sparse-sparse
  // matmul. csrValC and csrColIndC must be allocated on the device
  // with nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). See:
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgemm.
  template <typename Scalar>
  Status Csrgemm(cusparseOperation_t transA, cusparseOperation_t transB, int m,
                 int k, int n, const cusparseMatDescr_t descrA, int nnzA,
                 const Scalar* csrSortedValA, const int* csrSortedRowPtrA,
                 const int* csrSortedColIndA, const cusparseMatDescr_t descrB,
                 int nnzB, const Scalar* csrSortedValB,
                 const int* csrSortedRowPtrB, const int* csrSortedColIndB,
                 const cusparseMatDescr_t descrC, Scalar* csrSortedValC,
                 int* csrSortedRowPtrC, int* csrSortedColIndC);

  // In-place reordering of unsorted CSR to sorted CSR.
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr
  template <typename Scalar>
  Status Csru2csr(int m, int n, int nnz, const cusparseMatDescr_t descrA,
                  Scalar* csrVal, const int* csrRowPtr, int* csrColInd);

  // Converts from CSR to CSC format (equivalently, transpose).
  // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx
  template <typename Scalar>
  Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal,
                 const int* csrRowPtr, const int* csrColInd, Scalar* cscVal,
                 int* cscRowInd, int* cscColPtr,
                 const cusparseAction_t copyValues);

 private:
  bool initialized_;
  OpKernelContext *context_;  // not owned.
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparse);
|
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparse);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized
|
||||||
|
// only once. For more details on the descriptor (cusparseMatDescr_t), see:
|
||||||
|
// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt
|
||||||
|
class CudaSparseMatrixDescriptor {
|
||||||
|
public:
|
||||||
|
explicit CudaSparseMatrixDescriptor() : initialized_(false) {}
|
||||||
|
|
||||||
|
CudaSparseMatrixDescriptor(CudaSparseMatrixDescriptor&& rhs)
|
||||||
|
: initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) {
|
||||||
|
rhs.initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CudaSparseMatrixDescriptor& operator=(CudaSparseMatrixDescriptor&& rhs) {
|
||||||
|
if (this == &rhs) return *this;
|
||||||
|
Release();
|
||||||
|
initialized_ = rhs.initialized_;
|
||||||
|
descr_ = std::move(rhs.descr_);
|
||||||
|
rhs.initialized_ = false;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~CudaSparseMatrixDescriptor() { Release(); }
|
||||||
|
|
||||||
|
// Initializes the underlying descriptor. Will fail on the second call if
|
||||||
|
// called more than once.
|
||||||
|
Status Initialize() {
|
||||||
|
DCHECK(!initialized_);
|
||||||
|
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descr_));
|
||||||
|
initialized_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
cusparseMatDescr_t& descr() {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
return descr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const cusparseMatDescr_t& descr() const {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
return descr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Release() {
|
||||||
|
if (initialized_) {
|
||||||
|
cusparseDestroyMatDescr(descr_);
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool initialized_;
|
||||||
|
cusparseMatDescr_t descr_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseMatrixDescriptor);
|
||||||
|
};
|
||||||
|
|
||||||
|
// A wrapper class to ensure that an unsorted/sorted CSR conversion information
|
||||||
|
// struct (csru2csrInfo_t) is initialized only once. See:
|
||||||
|
// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr
|
||||||
|
class CudaSparseCsrSortingConversionInfo {
|
||||||
|
public:
|
||||||
|
explicit CudaSparseCsrSortingConversionInfo() : initialized_(false) {}
|
||||||
|
|
||||||
|
CudaSparseCsrSortingConversionInfo(CudaSparseCsrSortingConversionInfo&& rhs)
|
||||||
|
: initialized_(rhs.initialized_), info_(std::move(rhs.info_)) {
|
||||||
|
rhs.initialized_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CudaSparseCsrSortingConversionInfo& operator=(
|
||||||
|
CudaSparseCsrSortingConversionInfo&& rhs) {
|
||||||
|
if (this == &rhs) return *this;
|
||||||
|
Release();
|
||||||
|
initialized_ = rhs.initialized_;
|
||||||
|
info_ = std::move(rhs.info_);
|
||||||
|
rhs.initialized_ = false;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~CudaSparseCsrSortingConversionInfo() { Release(); }
|
||||||
|
|
||||||
|
// Initializes the underlying info. Will fail on the second call if called
|
||||||
|
// more than once.
|
||||||
|
Status Initialize() {
|
||||||
|
DCHECK(!initialized_);
|
||||||
|
TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_));
|
||||||
|
initialized_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
csru2csrInfo_t& info() {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
return info_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const csru2csrInfo_t& info() const {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
return info_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Release() {
|
||||||
|
if (initialized_) {
|
||||||
|
cusparseDestroyCsru2csrInfo(info_);
|
||||||
|
initialized_ = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool initialized_;
|
||||||
|
csru2csrInfo_t info_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(CudaSparseCsrSortingConversionInfo);
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
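For context, the kernels below all consume the same CSR triplet (row pointers, column indices, values) that the CudaSparse helpers above wrap. The following is a minimal standalone sketch of that layout, illustrative only and not part of this change; it uses no TensorFlow types.

// Illustrative sketch: CSR arrays for a tiny 3x4 matrix, plus reading one row.
#include <cstdio>
#include <vector>

int main() {
  // Dense matrix (3 x 4):
  //   [[1 0 0 2]
  //    [0 0 3 0]
  //    [0 4 0 5]]
  // row_ptrs has one entry per row plus a trailing total-nnz entry, so the
  // nonzeros of row r live in the half-open range [row_ptrs[r], row_ptrs[r+1]).
  std::vector<int> row_ptrs = {0, 2, 3, 5};
  std::vector<int> col_inds = {0, 3, 2, 1, 3};
  std::vector<float> values = {1, 2, 3, 4, 5};

  const int r = 2;
  for (int i = row_ptrs[r]; i < row_ptrs[r + 1]; ++i) {
    std::printf("row %d, col %d -> %g\n", r, col_inds[i], values[i]);
  }
  return 0;
}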
100  tensorflow/core/kernels/sparse/BUILD  Normal file
@ -0,0 +1,100 @@
# Description: Op kernels for sparse matrix operations.

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
    "tf_kernel_library",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],  # Apache 2.0
)

cc_library(
    name = "sparse_matrix",
    srcs = ["sparse_matrix.cc"],
    hdrs = ["sparse_matrix.h"],
    deps = [
        "//tensorflow/core:framework",
        "//third_party/eigen3",
    ],
)

tf_kernel_library(
    name = "kernels",
    srcs = [
        "add_op.cc",
        "conj_op.cc",
        "csr_sparse_matrix_to_dense_op.cc",
        "csr_sparse_matrix_to_sparse_tensor_op.cc",
        "dense_to_csr_sparse_matrix_op.cc",
        "kernels.cc",
        "mat_mul_op.cc",
        "mul_op.cc",
        "nnz_op.cc",
        "softmax_op.cc",
        "sparse_cholesky_op.cc",
        "sparse_mat_mul_op.cc",
        "sparse_matrix_components_op.cc",
        "sparse_ordering_amd_op.cc",
        "sparse_tensor_to_csr_sparse_matrix_op.cc",
        "transpose_op.cc",
        "zeros_op.cc",
    ],
    hdrs = [
        "kernels.h",
        "transpose_op.h",
        "zeros_op.h",
    ],
    gpu_srcs = [
        "zeros_op.h",
        "kernels.h",
        "kernels_gpu.cu.cc",
    ],
    deps = [
        ":sparse_matrix",
        "//third_party/eigen3",
        "//tensorflow/core:array_ops_op_lib",
        "//tensorflow/core:bitwise_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:functional_ops_op_lib",
        "//tensorflow/core:lib",
        "//tensorflow/core:math_ops_op_lib",
        "//tensorflow/core:nn_ops_op_lib",
        "//tensorflow/core:no_op_op_lib",
        "//tensorflow/core:sendrecv_ops_op_lib",
        "//tensorflow/core:sparse_csr_matrix_ops_op_lib",
        "//tensorflow/core:state_ops_op_lib",
        "//tensorflow/core/kernels:concat_lib",
        "//tensorflow/core/kernels:constant_op",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:dense_update_functor",
        "//tensorflow/core/kernels:fill_functor",
        "//tensorflow/core/kernels:gather_nd_op",
        "//tensorflow/core/kernels:scatter_nd_op",
        "//tensorflow/core/kernels:slice_op",
        "//tensorflow/core/kernels:transpose_functor",
    ] + if_cuda([
        "//tensorflow/core/kernels:cuda_solvers",
        "//tensorflow/core/kernels:cuda_sparse",
    ]),
    alwayslink = 1,
)

tf_cc_test(
    name = "kernels_test",
    size = "small",
    srcs = [
        "kernels_test.cc",
    ],
    deps = [
        ":kernels",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:testlib",
        "//third_party/eigen3",
    ],
)
342  tensorflow/core/kernels/sparse/add_op.cc  Normal file
@ -0,0 +1,342 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/fill_functor.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
template <typename Device, typename T>
class CSRSparseMatrixAddFunctor {
 public:
  explicit CSRSparseMatrixAddFunctor(OpKernelContext* ctx, const T alpha,
                                     const T beta)
      : ctx_(ctx), alpha_(alpha), beta_(beta) {}

  Status operator()(const CSRSparseMatrix& a, const CSRSparseMatrix& b,
                    CSRSparseMatrix* c) {
    TensorShape a_tensor_shape;
    TensorShape b_tensor_shape;
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(a.dense_shape().vec<int64>(),
                                                   &a_tensor_shape));
    TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(b.dense_shape().vec<int64>(),
                                                   &b_tensor_shape));

    if (a_tensor_shape.dims() == 3) {
      if ((a_tensor_shape.dims() != b_tensor_shape.dims()) ||
          (a_tensor_shape.dim_size(0) != b_tensor_shape.dim_size(0))) {
        return errors::InvalidArgument(
            "Incompatible shapes of a and b, a.shape == ",
            a_tensor_shape.DebugString(),
            ", b.shape == ", b_tensor_shape.DebugString());
      }
    }
    const int rank = a_tensor_shape.dims();
    if ((a_tensor_shape.dim_size(rank - 2) !=
         b_tensor_shape.dim_size(rank - 2)) ||
        (a_tensor_shape.dim_size(rank - 1) !=
         b_tensor_shape.dim_size(rank - 1))) {
      return errors::InvalidArgument(
          "Incompatible shapes of a and b, a.shape == ",
          a_tensor_shape.DebugString(),
          ", b.shape == ", b_tensor_shape.DebugString());
    }

    const int batch_size = a.batch_size();

    // TODO(ebrevdo): Add support for broadcasting at least in the
    // batch dimension.
    auto a_dense_shape = a.dense_shape().vec<int64>();
    auto b_dense_shape = b.dense_shape().vec<int64>();
    Tensor c_dense_shape_t = a.dense_shape();

    const int64 rows = a_dense_shape((rank == 2) ? 0 : 1);

    functor::CSRSparseMatrixAdd<Device, T> csr_geam(ctx_, alpha_, beta_);
    TF_RETURN_IF_ERROR(csr_geam.Initialize());

    Tensor c_batch_ptr_t(cpu_allocator(), DT_INT32,
                         TensorShape({batch_size + 1}));
    auto c_batch_ptr = c_batch_ptr_t.vec<int32>();
    c_batch_ptr(0) = 0;

    Tensor c_row_ptr_t;
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DT_INT32, TensorShape({batch_size * (rows + 1)}), &c_row_ptr_t));
    auto c_row_ptr = c_row_ptr_t.vec<int32>();

    // Set the output row pointers to zero, in case we hit any empty
    // combinations of rows in a and b.
    functor::SetZeroFunctor<Device, int32> set_zero;
    const Device& d = ctx_->eigen_device<Device>();
    set_zero(d, c_row_ptr_t.flat<int32>());

    for (int i = 0; i < batch_size; ++i) {
      // Calculate output sizes for all minibatch entries.
      // Store in c_batch_ptr and update c_row_ptrs.
      if (a.nnz(i) == 0 && b.nnz(i) == 0) {
        c_batch_ptr(i + 1) = c_batch_ptr(i);
        continue;
      }
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};
      TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
                                              rows + 1);
      int c_nnz_i;
      TF_RETURN_IF_ERROR(
          csr_geam.GetOutputStructure(a_comp, b_comp, c_row_ptr_i, &c_nnz_i));
      c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
    }

    Tensor c_col_ind_t;
    Tensor c_values_t;

    const int total_nnz = c_batch_ptr(batch_size);

    TF_RETURN_IF_ERROR(
        ctx_->allocate_temp(DT_INT32, TensorShape({total_nnz}), &c_col_ind_t));
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DataTypeToEnum<T>::value, TensorShape({total_nnz}), &c_values_t));
    TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
        DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t, c_row_ptr_t,
        c_col_ind_t, c_values_t, c));

    for (int i = 0; i < batch_size; ++i) {
      if (a.nnz(i) == 0 && b.nnz(i) == 0) {
        // Setting of c_row_pointers_vec(i) == 0 is already done.
        continue;
      }
      ConstCSRComponent<T> a_comp{a.row_pointers_vec(i), a.col_indices_vec(i),
                                  a.values_vec<T>(i), a_dense_shape};
      ConstCSRComponent<T> b_comp{b.row_pointers_vec(i), b.col_indices_vec(i),
                                  b.values_vec<T>(i), b_dense_shape};
      CSRComponent<T> c_comp{c->row_pointers_vec(i), c->col_indices_vec(i),
                             c->values_vec<T>(i), c_dense_shape_t.vec<int64>()};

      TF_RETURN_IF_ERROR(csr_geam.Compute(a_comp, b_comp, &c_comp));
    }

    return Status::OK();
  }

 private:
  OpKernelContext* ctx_;
  const T alpha_;
  const T beta_;
};

template <typename Device, typename T>
class CSRSparseMatrixSumFunctor : public CSRSparseMatrixAddFunctor<Device, T> {
 public:
  // Same as above, but with alpha = beta = 1.0, so C = 1.0 * A + 1.0 * B.
  explicit CSRSparseMatrixSumFunctor(OpKernelContext* ctx)
      : CSRSparseMatrixAddFunctor<Device, T>(ctx, 1, 1) {}
};

}  // namespace

template <typename Device, typename T>
class CSRAddOp : public OpKernel {
 public:
  explicit CSRAddOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    const CSRSparseMatrix* b_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &b_matrix));

    OP_REQUIRES(
        ctx, a_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of a is not equal to 'type': ",
                                DataTypeString(a_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));
    OP_REQUIRES(
        ctx, b_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of b is not equal to 'type': ",
                                DataTypeString(b_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    const Tensor& alpha_t = ctx->input(2);
    const Tensor& beta_t = ctx->input(3);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(alpha_t.shape()),
        errors::InvalidArgument("Expected alpha to be a scalar, saw shape: ",
                                alpha_t.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(beta_t.shape()),
        errors::InvalidArgument("Expected beta to be a scalar, saw shape: ",
                                beta_t.shape().DebugString()));

    const T host_alpha = alpha_t.scalar<T>()();
    const T host_beta = beta_t.scalar<T>()();

    Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    CSRSparseMatrix c_matrix;
    CSRSparseMatrixAddFunctor<Device, T> add_functor(ctx, host_alpha,
                                                     host_beta);
    OP_REQUIRES_OK(ctx, add_functor(*a_matrix, *b_matrix, &c_matrix));
    c_t.scalar<Variant>()() = std::move(c_matrix);
    ctx->set_output(0, c_t);
  }
};

#define REGISTER(DEV, T)                              \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixAdd")     \
                              .Device(DEVICE_##DEV)   \
                              .TypeConstraint<T>("T") \
                              .HostMemory("alpha")    \
                              .HostMemory("beta"),    \
                          CSRAddOp<DEV##Device, T>);

#if GOOGLE_CUDA

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#undef REGISTER_GPU

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(
    ADD_VARIANT_BINARY_OP, DEVICE_GPU, CSRSparseMatrix,
    (CSRSparseMatrixBinaryHelper<GPUDevice, CSRSparseMatrixSumFunctor>));

#endif  // GOOGLE_CUDA

#undef REGISTER

#if GOOGLE_CUDA
namespace functor {
template <typename T>
struct CSRSparseMatrixAdd<GPUDevice, T>
    : public CSRStructureModifyingFunctor<GPUDevice, T> {
  explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha, const T beta)
      : ctx_(ctx),
        cuda_sparse_(ctx),
        alpha_(alpha),
        beta_(beta),
        initialized_(false) {}

  Status Initialize() {
    TF_RETURN_IF_ERROR(cuda_sparse_.Initialize());
    TF_RETURN_IF_ERROR(descrA_.Initialize());
    TF_RETURN_IF_ERROR(descrB_.Initialize());
    TF_RETURN_IF_ERROR(descrC_.Initialize());
    initialized_ = true;
    return Status::OK();
  }

  Status GetOutputStructure(const ConstCSRComponent<T>& a,
                            const ConstCSRComponent<T>& b,
                            TTypes<int32>::UnalignedVec c_row_ptr,
                            int* output_nnz) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();
    *output_nnz = -1;

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));

    TF_RETURN_IF_ERROR(cuda_sparse_.CsrgeamNnz(
        m, n, descrA_.descr(), nnzA, a.row_ptr.data(), a.col_ind.data(),
        descrB_.descr(), nnzB, b.row_ptr.data(), b.col_ind.data(),
        descrC_.descr(), c_row_ptr.data(), output_nnz));

    if (*output_nnz < 0) {
      return errors::Internal(
          "CSRAdd: CsrgeamNnz returned nnzTotalDevHostPtr < 0: ", *output_nnz);
    }
    return Status::OK();
  }

  Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
                 CSRComponent<T>* c) {
    DCHECK(initialized_);

    const int m = a.row_ptr.size() - 1;
    DCHECK_EQ(m, b.row_ptr.size() - 1);
    const int row_dim = a.dense_shape_host.size() == 2 ? 0 : 1;
    DCHECK_EQ(m, a.dense_shape_host(row_dim));
    DCHECK_EQ(m, b.dense_shape_host(row_dim));
    const int nnzA = a.col_ind.size();
    const int nnzB = b.col_ind.size();

    const int n = a.dense_shape_host(row_dim + 1);
    DCHECK_EQ(n, b.dense_shape_host(row_dim + 1));

    // Adding alpha * a + beta * b.
    TF_RETURN_IF_ERROR(cuda_sparse_.Csrgeam(
        m, n, &alpha_, descrA_.descr(), nnzA, a.values.data(), a.row_ptr.data(),
        a.col_ind.data(), &beta_, descrB_.descr(), nnzB, b.values.data(),
        b.row_ptr.data(), b.col_ind.data(), descrC_.descr(), c->values.data(),
        c->row_ptr.data(), c->col_ind.data()));

    return Status::OK();
  }

 private:
  OpKernelContext* ctx_;
  CudaSparse cuda_sparse_;
  CudaSparseMatrixDescriptor descrA_;
  CudaSparseMatrixDescriptor descrB_;
  CudaSparseMatrixDescriptor descrC_;
  const T alpha_;
  const T beta_;
  bool initialized_;

  TF_DISALLOW_COPY_AND_ASSIGN(CSRSparseMatrixAdd);
};

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
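The GPU functor above follows the usual two-pass "geam" structure: a first pass computes the output row pointers and total nnz, the output is then allocated, and a second pass fills the column indices and values. The sketch below shows the same two-pass structure for alpha * A + beta * B with plain vectors on the CPU; it is illustrative only (not the patch's GPU path) and assumes each row's column indices are sorted ascending.

// Illustrative sketch: two-pass CSR addition C = alpha * A + beta * B for
// equal-shaped matrices, mirroring the CsrgeamNnz / Csrgeam split.
#include <vector>

struct Csr {
  std::vector<int> row_ptr;   // size rows + 1
  std::vector<int> col_ind;   // size nnz
  std::vector<float> values;  // size nnz
};

// Pass 1: count each output row's nonzeros (union of the input columns).
static std::vector<int> CsrAddRowPtr(const Csr& a, const Csr& b) {
  const int rows = static_cast<int>(a.row_ptr.size()) - 1;
  std::vector<int> row_ptr(rows + 1, 0);
  for (int r = 0; r < rows; ++r) {
    int i = a.row_ptr[r], j = b.row_ptr[r], count = 0;
    while (i < a.row_ptr[r + 1] || j < b.row_ptr[r + 1]) {
      if (j >= b.row_ptr[r + 1] ||
          (i < a.row_ptr[r + 1] && a.col_ind[i] < b.col_ind[j])) {
        ++i;  // column only in A
      } else if (i >= a.row_ptr[r + 1] || b.col_ind[j] < a.col_ind[i]) {
        ++j;  // column only in B
      } else {
        ++i;  // column in both A and B
        ++j;
      }
      ++count;
    }
    row_ptr[r + 1] = row_ptr[r] + count;
  }
  return row_ptr;
}

// Pass 2: fill column indices and values into the preallocated output.
Csr CsrAdd(float alpha, const Csr& a, float beta, const Csr& b) {
  Csr c;
  c.row_ptr = CsrAddRowPtr(a, b);
  c.col_ind.resize(c.row_ptr.back());
  c.values.resize(c.row_ptr.back());
  const int rows = static_cast<int>(a.row_ptr.size()) - 1;
  for (int r = 0; r < rows; ++r) {
    int i = a.row_ptr[r], j = b.row_ptr[r], k = c.row_ptr[r];
    while (i < a.row_ptr[r + 1] || j < b.row_ptr[r + 1]) {
      if (j >= b.row_ptr[r + 1] ||
          (i < a.row_ptr[r + 1] && a.col_ind[i] < b.col_ind[j])) {
        c.col_ind[k] = a.col_ind[i];
        c.values[k] = alpha * a.values[i++];
      } else if (i >= a.row_ptr[r + 1] || b.col_ind[j] < a.col_ind[i]) {
        c.col_ind[k] = b.col_ind[j];
        c.values[k] = beta * b.values[j++];
      } else {
        c.col_ind[k] = a.col_ind[i];
        c.values[k] = alpha * a.values[i++] + beta * b.values[j++];
      }
      ++k;
    }
  }
  return c;
}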
98  tensorflow/core/kernels/sparse/conj_op.cc  Normal file
@ -0,0 +1,98 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {
template <typename Device, typename T>
class CSRSparseMatrixConjFunctor {
 public:
  explicit CSRSparseMatrixConjFunctor(OpKernelContext* ctx) : ctx_(ctx) {}

  Status operator()(const CSRSparseMatrix& a, CSRSparseMatrix* b) {
    const int total_nnz = a.total_nnz();
    Tensor b_values_t;
    TF_RETURN_IF_ERROR(ctx_->allocate_temp(
        DataTypeToEnum<T>::value, TensorShape({total_nnz}), &b_values_t));
    TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
        DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
        a.row_pointers(), a.col_indices(), b_values_t, b));

    const Device& d = ctx_->eigen_device<Device>();
    functor::UnaryFunctor<Device, functor::conj<T>> func;
    func(d, b->values().flat<T>() /*out*/, a.values().flat<T>() /*in*/);

    return Status::OK();
  }

 private:
  OpKernelContext* ctx_;
};

// Partial specialization for real types where conjugation is a noop.
#define NOOP_CONJ_FUNCTOR(T)                                              \
  template <typename Device>                                              \
  class CSRSparseMatrixConjFunctor<Device, T> {                           \
   public:                                                                \
    explicit CSRSparseMatrixConjFunctor(OpKernelContext* ctx) {}          \
    Status operator()(const CSRSparseMatrix& a, CSRSparseMatrix* b) {     \
      TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(          \
          DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),  \
          a.row_pointers(), a.col_indices(), a.values(), b));             \
      return Status::OK();                                                \
    }                                                                     \
  };

NOOP_CONJ_FUNCTOR(float);
NOOP_CONJ_FUNCTOR(double);

#undef NOOP_CONJ_FUNCTOR

}  // namespace

#if GOOGLE_CUDA

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
    CONJ_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix,
    (CSRSparseMatrixUnaryHelper<GPUDevice, CSRSparseMatrixConjFunctor>));

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
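Conjugation only touches the values array: the sparsity structure (dense shape, batch and row pointers, column indices) is reused unchanged, and for real element types the values tensor can be aliased as well, which is why the no-op specialization above skips the elementwise pass. A minimal standalone sketch of that elementwise step, illustrative only and outside the patch, is:

// Illustrative sketch: conjugating just the CSR values array.
#include <complex>
#include <vector>

std::vector<std::complex<float>> ConjValues(
    const std::vector<std::complex<float>>& values) {
  std::vector<std::complex<float>> out(values.size());
  for (size_t i = 0; i < values.size(); ++i) out[i] = std::conj(values[i]);
  return out;  // row_ptr and col_ind of the input matrix are reused as-is
}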
267  tensorflow/core/kernels/sparse/csr_sparse_matrix_to_dense_op.cc  Normal file
@ -0,0 +1,267 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/scatter_nd_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Op to convert a (batched) CSR SparseMatrix to dense Tensors on the CPU.
// The resulting Tensor will have rank 2 or (if batched) 3. Missing values in
// the CSR SparseMatrix are interpreted as zeros in the dense Tensor.
template <typename Device, typename T>
class CSRSparseMatrixToDenseCPUOp : public OpKernel {
 public:
  explicit CSRSparseMatrixToDenseCPUOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* context) override {
    const CSRSparseMatrix* csr_sparse_matrix;
    OP_REQUIRES_OK(context,
                   ExtractVariantFromInput(context, 0, &csr_sparse_matrix));

    OP_REQUIRES(
        context, csr_sparse_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument(
            "Asked for a CSRSparseMatrix of type ",
            DataTypeString(DataTypeToEnum<T>::value),
            " but saw dtype: ", DataTypeString(csr_sparse_matrix->dtype())));

    const Tensor& dense_shape_t = csr_sparse_matrix->dense_shape();
    const int rank = dense_shape_t.dim_size(0);
    OP_REQUIRES(context, rank == 2 || rank == 3,
                errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
                                        "but dense_shape has size ", rank));

    auto dense_shape = dense_shape_t.vec<int64>();
    const int64 num_rows = dense_shape((rank == 2) ? 0 : 1);
    const int64 num_cols = dense_shape((rank == 2) ? 1 : 2);

    auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();
    auto row_ptr = csr_sparse_matrix->row_pointers().vec<int32>();
    auto col_ind = csr_sparse_matrix->col_indices().vec<int32>();
    auto values = csr_sparse_matrix->values().vec<T>();

    TensorShape dense_tensor_shape;
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(dense_shape.data(),
                                                        dense_shape.size(),
                                                        &dense_tensor_shape));
    Tensor dense_t(cpu_allocator(), DataTypeToEnum<T>::value,
                   dense_tensor_shape);

    // Fill the dense tensor with zeros.
    functor::SetZeroFunctor<Device, T> set_zero;
    set_zero(context->eigen_device<Device>(), dense_t.flat<T>());

    auto dense_ptr = dense_t.flat<T>().data();

    // Process the individual batches in parallel using a threadpool.
    auto shard = [&](int64 batch_begin, int64 batch_end) {
      for (int64 batch_idx = batch_begin; batch_idx < batch_end; ++batch_idx) {
        const int64 csr_batch_offset = batch_ptrs(batch_idx);
        const int64 dense_batch_offset = batch_idx * num_rows * num_cols;

        for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
          const int64 row_offset = batch_idx * (num_rows + 1) + row_idx;
          const int64 col_begin = row_ptr(row_offset);
          const int64 col_end = row_ptr(row_offset + 1);
          for (int64 i = col_begin; i < col_end; ++i) {
            const int64 col_idx = col_ind(csr_batch_offset + i);
            dense_ptr[dense_batch_offset + (row_idx * num_cols) + col_idx] =
                values(csr_batch_offset + i);
          }
        }
      }
    };
    const int batch_size = csr_sparse_matrix->batch_size();
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          csr_sparse_matrix->total_nnz() / batch_size /* cost per unit */,
          shard);

    context->set_output(0, dense_t);
  }
};

template <typename Device, typename T>
class CSRSparseMatrixToDenseGPUOp : public OpKernel {
 public:
  explicit CSRSparseMatrixToDenseGPUOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) final {
    const CSRSparseMatrix* csr_sparse_matrix;
    OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));

    OP_REQUIRES(
        c, csr_sparse_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument(
            "Asked for a CSRSparseMatrix of type ",
            DataTypeString(DataTypeToEnum<T>::value),
            " but saw dtype: ", DataTypeString(csr_sparse_matrix->dtype())));

    const Tensor& dense_shape_t = csr_sparse_matrix->dense_shape();
    const int rank = dense_shape_t.dim_size(0);
    OP_REQUIRES(c, rank == 2 || rank == 3,
                errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
                                        "but dense_shape has size ", rank));

    const int batch_size = csr_sparse_matrix->batch_size();
    const int64 total_nnz = csr_sparse_matrix->total_nnz();

    auto dense_shape = dense_shape_t.vec<int64>();
    const int64 rows = dense_shape((rank == 2) ? 0 : 1);

    Tensor indices_t;
    OP_REQUIRES_OK(c, c->allocate_temp(DT_INT64, TensorShape({total_nnz, rank}),
                                       &indices_t));

    Tensor values_t;
    OP_REQUIRES_OK(c, c->allocate_temp(DataTypeToEnum<T>::value,
                                       TensorShape({total_nnz}), &values_t));

    functor::CSRSparseMatrixToCOOSparseMatrix<Device> csr_to_coo;
    auto indices = indices_t.matrix<int64>();

    auto csr_row_ptr = csr_sparse_matrix->row_pointers().vec<int32>();
    auto coo_col_ind = csr_sparse_matrix->col_indices().vec<int32>();
    auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();

    Tensor coo_row_ind_t;
    OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                       &coo_row_ind_t));
    auto coo_row_ind = coo_row_ind_t.vec<int32>();

    // TODO(ebrevdo): just write a custom kernel that converts from
    // csr to dense.
    for (int i = 0; i < batch_size; ++i) {
      const int nnz_i = csr_sparse_matrix->nnz(i);
      if (nnz_i == 0) {
        // No copying required.  Avoid failure case below.
        continue;
      }
      const TTypes<int32>::UnalignedConstVec csr_row_ptr_i(
          &csr_row_ptr((rows + 1) * i), rows + 1);
      const TTypes<int32>::UnalignedVec coo_row_ind_i(
          &coo_row_ind(csr_sparse_matrix->batch_offset(i)), nnz_i);
      OP_REQUIRES_OK(c, csr_to_coo(c, csr_row_ptr_i, coo_row_ind_i));
    }

    if (total_nnz > 0) {
      functor::COOSparseMatrixToSparseTensor<Device> coo_to_st;
      OP_REQUIRES_OK(c, coo_to_st(c, dense_shape, batch_ptrs, coo_row_ind,
                                  coo_col_ind, indices));
    }

    values_t = csr_sparse_matrix->values();

    Tensor dense_t;
    TensorShape dense_tensor_shape;
    OP_REQUIRES_OK(
        c, TensorShapeUtils::MakeShape(dense_shape.data(), dense_shape.size(),
                                       &dense_tensor_shape));
    OP_REQUIRES_OK(
        c,
        functor::DoScatterNd<Device, T, int64, scatter_nd_op::UpdateOp::ASSIGN>(
            c, indices_t, values_t, dense_tensor_shape, &dense_t,
            true /*allocate*/));
    c->set_output(0, dense_t);
  }
};

#define REGISTER_GPU(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToDense")  \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<T>("type"), \
                          CSRSparseMatrixToDenseGPUOp<GPUDevice, T>);

#define REGISTER_CPU(T)                                   \
  REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToDense")  \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<T>("type"), \
                          CSRSparseMatrixToDenseCPUOp<CPUDevice, T>);
REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#if GOOGLE_CUDA

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#endif  // GOOGLE_CUDA

#undef REGISTER_CPU
#undef REGISTER_GPU

#if GOOGLE_CUDA

namespace functor {
template <>
struct COOSparseMatrixToSparseTensor<GPUDevice> {
  Status operator()(OpKernelContext* ctx,
                    TTypes<int64>::ConstVec host_dense_shape,
                    TTypes<int>::ConstVec host_batch_ptrs,
                    TTypes<int>::Vec coo_row_ind,
                    TTypes<int>::ConstVec coo_col_ind,
                    TTypes<int64>::Matrix indices);
};
extern template struct COOSparseMatrixToSparseTensor<GPUDevice>;

// TODO(ebrevdo): Write a custom batch-friendly impl of this to update
// the SparseTensor indices directly.
template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
    TTypes<int>::UnalignedVec coo_row_ind) {
  CudaSparse cuda_sparse(c);
  const int nnz = coo_row_ind.size();
  TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
  const int m = csr_row_ptr.size() - 1;  // rows
  return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}
extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
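The CPU path above zero-fills the output and then scatters each nonzero with dense[row * num_cols + col] = value, one CSR row at a time. The following is a minimal single-matrix (unbatched) sketch of that loop, illustrative only and not part of the patch:

// Illustrative sketch: CSR -> dense scatter for a rows x cols matrix.
#include <vector>

std::vector<float> CsrToDense(int rows, int cols,
                              const std::vector<int>& row_ptr,
                              const std::vector<int>& col_ind,
                              const std::vector<float>& values) {
  std::vector<float> dense(rows * cols, 0.0f);  // missing entries stay zero
  for (int r = 0; r < rows; ++r) {
    for (int i = row_ptr[r]; i < row_ptr[r + 1]; ++i) {
      dense[r * cols + col_ind[i]] = values[i];
    }
  }
  return dense;
}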
@ -0,0 +1,264 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/util/work_sharder.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {
namespace {

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// Validates that the CSR SparseMatrix has the expected dtype and rank 2 or 3.
Status ValidateCSRSparseMatrix(const CSRSparseMatrix& csr_sparse_matrix,
                               DataType expected_dtype) {
  if (csr_sparse_matrix.dtype() != expected_dtype) {
    return errors::InvalidArgument(
        "Expected a CSRSparseMatrix of type ", DataTypeString(expected_dtype),
        " but saw type: ", DataTypeString(csr_sparse_matrix.dtype()));
  }
  const int rank = csr_sparse_matrix.dense_shape().dim_size(0);
  if (rank != 2 && rank != 3) {
    return errors::InvalidArgument("CSR SparseMatrix must have rank 2 or 3; ",
                                   "but dense_shape has size ", rank);
  }
  return Status::OK();
}
}  // namespace

// Op to convert a (batched) CSR SparseMatrix to SparseTensors on the CPU.
// The resulting SparseTensor will have the same dense shape and non-zero
// values as the CSR SparseMatrix, with rank 2 or (if batched) 3. Moreover, the
// resulting SparseTensor's indices will be present in the canonical,
// row-major ordering.
template <typename T>
class CSRSparseMatrixToSparseTensorCPUOp : public OpKernel {
 public:
  explicit CSRSparseMatrixToSparseTensorCPUOp(OpKernelConstruction* c)
      : OpKernel(c) {}

  void Compute(OpKernelContext* c) final {
    const CSRSparseMatrix* csr_sparse_matrix;
    OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));
    OP_REQUIRES_OK(c, ValidateCSRSparseMatrix(*csr_sparse_matrix,
                                              DataTypeToEnum<T>::value));

    // Copy the SparseTensor's dense_shape and values from the CSRSparseMatrix.
    c->set_output(1, csr_sparse_matrix->values());
    const Tensor& dense_shape = csr_sparse_matrix->dense_shape();
    c->set_output(2, dense_shape);

    const int batch_size = csr_sparse_matrix->batch_size();
    const int64 total_nnz = csr_sparse_matrix->total_nnz();
    const int rank = csr_sparse_matrix->dense_shape().dim_size(0);
    auto dense_shape_vec = dense_shape.vec<int64>();
    const int64 num_rows = dense_shape_vec((rank == 2) ? 0 : 1);

    Tensor* indices;
    OP_REQUIRES_OK(
        c, c->allocate_output(0, TensorShape({total_nnz, rank}), &indices));
    auto indices_flat = indices->template flat<int64>();

    auto csr_row_ptr = csr_sparse_matrix->row_pointers().vec<int32>();
    auto csr_col_ind = csr_sparse_matrix->col_indices().vec<int32>();
    auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();

    // Process the individual batches in parallel using a threadpool.
    auto shard = [&](int64 batch_begin, int64 batch_end) {
      for (int64 batch_idx = batch_begin; batch_idx < batch_end; ++batch_idx) {
        const int64 csr_batch_offset = batch_ptrs(batch_idx);

        for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
          const int64 row_offset = batch_idx * (num_rows + 1) + row_idx;

          // The column indices of the current row lie in the range:
          //   [csr_row_ptr[row_offset], csr_row_ptr[row_offset + 1])
          const int64 col_begin = csr_row_ptr(row_offset);
          const int64 col_end = csr_row_ptr(row_offset + 1);
          for (int64 i = col_begin; i < col_end; ++i) {
            const int64 col_idx = csr_col_ind(csr_batch_offset + i);
            const int64 indices_offset = rank * (csr_batch_offset + i);

            if (rank == 2) {
              indices_flat(indices_offset) = row_idx;
              indices_flat(indices_offset + 1) = col_idx;
            } else {  // rank == 3
              indices_flat(indices_offset) = batch_idx;
              indices_flat(indices_offset + 1) = row_idx;
              indices_flat(indices_offset + 2) = col_idx;
            }
          }
        }
      }
    };
    auto worker_threads = *(c->device()->tensorflow_cpu_worker_threads());
    // TODO(anudhyan): Estimate the cost per unit based on Eigen::TensorOpCost
    // units and scale based on benchmarks.
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
          csr_sparse_matrix->total_nnz() / batch_size /* cost per unit */,
          shard);
  }
};

template <typename Device, typename T>
class CSRSparseMatrixToSparseTensorGPUOp : public OpKernel {
 public:
  explicit CSRSparseMatrixToSparseTensorGPUOp(OpKernelConstruction* c)
      : OpKernel(c) {}

  void Compute(OpKernelContext* c) final {
    const CSRSparseMatrix* csr_sparse_matrix;
    OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));
    OP_REQUIRES_OK(c, ValidateCSRSparseMatrix(*csr_sparse_matrix,
                                              DataTypeToEnum<T>::value));

    const Tensor& dense_shape_t = csr_sparse_matrix->dense_shape();
    c->set_output(2, dense_shape_t);
    const int rank = dense_shape_t.dim_size(0);
    const int batch_size = csr_sparse_matrix->batch_size();
    const int64 total_nnz = csr_sparse_matrix->total_nnz();

    auto dense_shape = dense_shape_t.vec<int64>();
    const int64 rows = dense_shape((rank == 2) ? 0 : 1);

    Tensor* indices_t;
    OP_REQUIRES_OK(
        c, c->allocate_output(0, TensorShape({total_nnz, rank}), &indices_t));

    Tensor* values_t;
    OP_REQUIRES_OK(c,
                   c->allocate_output(1, TensorShape({total_nnz}), &values_t));

    functor::CSRSparseMatrixToCOOSparseMatrix<Device> csr_to_coo;
    auto indices = indices_t->matrix<int64>();

    auto csr_row_ptr = csr_sparse_matrix->row_pointers().vec<int32>();
    auto coo_col_ind = csr_sparse_matrix->col_indices().vec<int32>();
    auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();

    Tensor coo_row_ind_t;
    OP_REQUIRES_OK(c, c->allocate_temp(DT_INT32, TensorShape({total_nnz}),
                                       &coo_row_ind_t));
    auto coo_row_ind = coo_row_ind_t.vec<int32>();

    // TODO(ebrevdo): Convert to one or two single kernel calls,
    // where the kernels are batch-friendly.
    for (int i = 0; i < batch_size; ++i) {
      const int nnz_i = csr_sparse_matrix->nnz(i);
      if (nnz_i == 0) {
        // No copying required.  Avoid failure case below.
        continue;
      }
      const TTypes<int32>::UnalignedConstVec csr_row_ptr_i(
          &csr_row_ptr((rows + 1) * i), rows + 1);
      const TTypes<int32>::UnalignedVec coo_row_ind_i(
          &coo_row_ind(csr_sparse_matrix->batch_offset(i)), nnz_i);
      OP_REQUIRES_OK(c, csr_to_coo(c, csr_row_ptr_i, coo_row_ind_i));
    }

    if (total_nnz > 0) {
      functor::COOSparseMatrixToSparseTensor<Device> coo_to_st;
      OP_REQUIRES_OK(c, coo_to_st(c, dense_shape, batch_ptrs, coo_row_ind,
                                  coo_col_ind, indices));
    }

    *values_t = csr_sparse_matrix->values();
  }
};

#define REGISTER_GPU(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor")  \
                              .Device(DEVICE_GPU)                \
                              .TypeConstraint<T>("type")         \
                              .HostMemory("dense_shape"),        \
                          CSRSparseMatrixToSparseTensorGPUOp<GPUDevice, T>);

#if GOOGLE_CUDA

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#endif  // GOOGLE_CUDA

#undef REGISTER_GPU

#if GOOGLE_CUDA

namespace functor {
template <>
struct COOSparseMatrixToSparseTensor<GPUDevice> {
  Status operator()(OpKernelContext* ctx,
                    TTypes<int64>::ConstVec host_dense_shape,
                    TTypes<int>::ConstVec host_batch_ptrs,
                    TTypes<int>::Vec coo_row_ind,
                    TTypes<int>::ConstVec coo_col_ind,
                    TTypes<int64>::Matrix indices);
};
extern template struct COOSparseMatrixToSparseTensor<GPUDevice>;

// TODO(ebrevdo): Write a custom batch-friendly impl of this to update
// the SparseTensor indices directly.
template <>
Status CSRSparseMatrixToCOOSparseMatrix<GPUDevice>::operator()(
    OpKernelContext* c, TTypes<const int>::UnalignedVec csr_row_ptr,
    TTypes<int>::UnalignedVec coo_row_ind) {
  CudaSparse cuda_sparse(c);
  const int nnz = coo_row_ind.size();
  TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
  const int m = csr_row_ptr.size() - 1;  // rows
  return cuda_sparse.Csr2coo(csr_row_ptr.data(), nnz, m, coo_row_ind.data());
}
extern template struct CSRSparseMatrixToCOOSparseMatrix<GPUDevice>;

}  // namespace functor

#endif  // GOOGLE_CUDA

#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixToSparseTensor")  \
                              .Device(DEVICE_CPU)                \
                              .TypeConstraint<T>("type"),        \
                          CSRSparseMatrixToSparseTensorCPUOp<T>);

REGISTER_CPU(float)
REGISTER_CPU(double)
REGISTER_CPU(complex64)
REGISTER_CPU(complex128)

#undef REGISTER_CPU

}  // namespace tensorflow
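The GPU path above first expands the CSR row pointers into one row index per nonzero (the CSR -> COO step it delegates to cuSPARSE via Csr2coo) and then pairs those row indices with the column indices to form the SparseTensor indices. A minimal unbatched sketch of that expansion, illustrative only and outside the patch:

// Illustrative sketch: CSR -> COO expansion producing [row, col] index pairs
// in canonical row-major order.
#include <array>
#include <vector>

std::vector<std::array<int, 2>> CsrToCooIndices(
    const std::vector<int>& row_ptr, const std::vector<int>& col_ind) {
  std::vector<std::array<int, 2>> indices(col_ind.size());
  const int rows = static_cast<int>(row_ptr.size()) - 1;
  for (int r = 0; r < rows; ++r) {
    for (int i = row_ptr[r]; i < row_ptr[r + 1]; ++i) {
      indices[i] = {r, col_ind[i]};  // nonzero i belongs to row r
    }
  }
  return indices;
}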
398
tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc
Normal file
398
tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/gather_nd_op.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||||
|
#include "tensorflow/core/platform/cuda.h"
|
||||||
|
|
||||||
|
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
// Op to convert dense matrices to CSR SparseMatrices on the CPU.
|
||||||
|
// Takes a Tensor of rank 2 or (if batched) 3 and a corresponding list of
|
||||||
|
// indices as input.
|
||||||
|
//
|
||||||
|
// The (batched) CSR SparseMatrix is constructed using only
|
||||||
|
// the values at the given indices. This implementation assumes that the indices
|
||||||
|
// are sorted with respect to batch indices and are in row-major order.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class DenseToCSRSparseMatrixCPUOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit DenseToCSRSparseMatrixCPUOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
const Tensor& params = ctx->input(0);
|
||||||
|
const Tensor& indices = ctx->input(1);
|
||||||
|
|
||||||
|
// TODO(anudhyan): Factor out common input validation for CPU and GPU ops
|
||||||
|
// into a single function.
|
||||||
|
const TensorShape& dense_tensor_shape = params.shape();
|
||||||
|
const int rank = params.dims();
|
||||||
|
OP_REQUIRES(ctx, rank == 2 || rank == 3,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"params must have rank == 2 or 3; ",
|
||||||
|
"but saw shape: ", dense_tensor_shape.DebugString()));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, indices.dims() == 2,
|
||||||
|
errors::InvalidArgument("indices must be a matrix, but saw shape: ",
|
||||||
|
indices.shape().DebugString()));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, indices.dim_size(1) == rank,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"indices.shape[1] must be equal to the rank of params, but saw: ",
|
||||||
|
indices.dim_size(1), " vs. ", rank));
|
||||||
|
|
||||||
|
Tensor dense_shape(cpu_allocator(), DT_INT64, TensorShape({rank}));
|
||||||
|
auto dense_shape_mutable = dense_shape.vec<int64>();
|
||||||
|
for (int i = 0; i < rank; ++i) {
|
||||||
|
dense_shape_mutable(i) = dense_tensor_shape.dim_size(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64 batch_size = (rank == 2) ? 1 : dense_tensor_shape.dim_size(0);
|
||||||
|
const int64 num_rows = dense_tensor_shape.dim_size((rank == 2) ? 0 : 1);
|
||||||
|
const int64 total_nnz = indices.NumElements() / rank;
|
||||||
|
|
||||||
|
Tensor values;
|
||||||
|
OP_REQUIRES_OK(ctx, functor::DoGatherNd<Device, T, int64>(
|
||||||
|
ctx, params, indices, &values));
|
||||||
|
|
||||||
|
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
|
||||||
|
Tensor csr_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
|
||||||
|
Tensor csr_row_ptr(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({(num_rows + 1) * batch_size}));
|
||||||
|
|
||||||
|
// Fill the row pointers with zeros.
|
||||||
|
functor::SetZeroFunctor<Device, int32> set_zero;
|
||||||
|
set_zero(ctx->eigen_device<Device>(), csr_row_ptr.flat<int32>());
|
||||||
|
|
||||||
|
// Convert from COO to CSR format.
|
||||||
|
functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
coo_to_csr(batch_size, num_rows, indices.matrix<int64>(),
|
||||||
|
batch_ptr.vec<int32>(), csr_row_ptr.vec<int32>(),
|
||||||
|
csr_col_ind.vec<int32>()));
|
||||||
|
|
||||||
|
CSRSparseMatrix output_csr_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
values.dtype(), dense_shape, batch_ptr, csr_row_ptr,
|
||||||
|
csr_col_ind, values, &output_csr_matrix));
|
||||||
|
Tensor* output_csr_matrix_tensor;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
|
||||||
|
cpu_alloc));
|
||||||
|
output_csr_matrix_tensor->scalar<Variant>()() =
|
||||||
|
std::move(output_csr_matrix);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_CPU(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("DenseToCSRSparseMatrix") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<T>("T"), \
|
||||||
|
DenseToCSRSparseMatrixCPUOp<CPUDevice, T>);
|
||||||
|
|
||||||
|
REGISTER_CPU(float)
|
||||||
|
REGISTER_CPU(double)
|
||||||
|
REGISTER_CPU(complex64)
|
||||||
|
REGISTER_CPU(complex128)
|
||||||
|
|
||||||
|
#undef REGISTER_CPU
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||||
|
public:
|
||||||
|
explicit DenseToCSRSparseMatrixGPUOp(OpKernelConstruction* c)
|
||||||
|
: AsyncOpKernel(c) {}
|
||||||
|
|
||||||
|
void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
|
||||||
|
auto stream = c->op_device_context()->stream();
|
||||||
|
const Device& d = c->eigen_device<Device>();
|
||||||
|
|
||||||
|
const Tensor& params_t = c->input(0);
|
||||||
|
const Tensor& indices_t = c->input(1);
|
||||||
|
const TensorShape& dense_tensor_shape = params_t.shape();
|
||||||
|
const int rank = params_t.dims();
|
||||||
|
OP_REQUIRES_ASYNC(c, rank == 2 || rank == 3,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"params must have rank == 2 or 3; ",
|
||||||
|
"but saw shape: ", dense_tensor_shape.DebugString()),
|
||||||
|
done);
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c, indices_t.dims() == 2,
|
||||||
|
errors::InvalidArgument("indices must be a matrix, but saw shape: ",
|
||||||
|
indices_t.shape().DebugString()),
|
||||||
|
done);
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c, indices_t.dim_size(1) == rank,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"indices.shape[1] must be equal to the rank of params, but saw: ",
|
||||||
|
indices_t.dim_size(1), " vs. ", rank),
|
||||||
|
done);
|
||||||
|
const int64 batch_size = (rank == 2) ? 1 : dense_tensor_shape.dim_size(0);
|
||||||
|
const int64 rows = dense_tensor_shape.dim_size((rank == 2) ? 0 : 1);
|
||||||
|
const int64 cols = dense_tensor_shape.dim_size((rank == 2) ? 1 : 2);
|
||||||
|
|
||||||
|
ScratchSpace<int32> nnz_per_batch_host(c, batch_size, /*on_host*/ true);
|
||||||
|
|
||||||
|
Tensor nnz_per_batch_device_t;
|
||||||
|
if (rank == 2) {
|
||||||
|
// Simple case.
|
||||||
|
nnz_per_batch_host.mutable_data()[0] = indices_t.dim_size(0);
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK_ASYNC(c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({batch_size}),
|
||||||
|
&nnz_per_batch_device_t),
|
||||||
|
done);
|
||||||
|
auto nnz_per_batch_device = nnz_per_batch_device_t.vec<int32>();
|
||||||
|
|
||||||
|
functor::CalculateNNZPerBatchMatrixFromIndices<Device>
|
||||||
|
calculate_nnz_from_indices;
|
||||||
|
auto indices = indices_t.matrix<int64>();
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c, calculate_nnz_from_indices(c, indices, nnz_per_batch_device),
|
||||||
|
done);
|
||||||
|
|
||||||
|
perftools::gputools::DeviceMemoryBase nnz_per_batch_device_ptr(
|
||||||
|
static_cast<void*>(nnz_per_batch_device.data()));
|
||||||
|
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c,
|
||||||
|
stream
|
||||||
|
->ThenMemcpy(nnz_per_batch_host.mutable_data() /*host_dst*/,
|
||||||
|
nnz_per_batch_device_ptr /*gpu_src*/,
|
||||||
|
batch_size * sizeof(int32) /*size*/)
|
||||||
|
.ok(),
|
||||||
|
errors::Internal("DenseToSparseMatrixGPUOp: failed to copy "
|
||||||
|
"nnz_per_batch from device"),
|
||||||
|
done);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ebrevdo): write a custom pair of kernels: one that
|
||||||
|
// calculates the batched csr_row_ptr vector, another that fills in
|
||||||
|
// the col_ind and values vectors.
|
||||||
|
TensorReference nnz_per_batch_device_ref(nnz_per_batch_device_t);
|
||||||
|
auto convert_to_csr = [this, c, rank, batch_size, nnz_per_batch_host,
|
||||||
|
nnz_per_batch_device_ref, stream, &d, ¶ms_t,
|
||||||
|
&indices_t, dense_tensor_shape, rows, cols, done]() {
|
||||||
|
// The data has been copied out of the nnz_per_batch_device
|
||||||
|
// tensor by the time we get here; we can unreference it.
|
||||||
|
nnz_per_batch_device_ref.Unref();
|
||||||
|
|
||||||
|
auto nnz_per_batch = nnz_per_batch_host.tensor().vec<int32>();
|
||||||
|
|
||||||
|
// Ensure that within the callback, the proper GPU settings are
|
||||||
|
// configured.
|
||||||
|
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
||||||
|
|
||||||
|
// Extract out the values.
|
||||||
|
Tensor temp_values_t;
|
||||||
|
OP_REQUIRES_OK_ASYNC(c,
|
||||||
|
(functor::DoGatherNd<Device, T, int64>(
|
||||||
|
c, params_t, indices_t, &temp_values_t)),
|
||||||
|
done);
|
||||||
|
const Tensor& values_t = const_cast<const Tensor&>(temp_values_t);
|
||||||
|
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c, TensorShapeUtils::IsVector(values_t.shape()),
|
||||||
|
errors::Internal("Expected values_t to be a vector, but saw shape: ",
|
||||||
|
values_t.shape().DebugString()),
|
||||||
|
done);
|
||||||
|
|
||||||
|
Tensor dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
|
||||||
|
auto dense_shape_mutable = dense_shape_t.vec<int64>();
|
||||||
|
for (int i = 0; i < rank; ++i) {
|
||||||
|
dense_shape_mutable(i) = dense_tensor_shape.dim_size(i);
|
||||||
|
}
|
||||||
|
auto dense_shape = const_cast<const Tensor&>(dense_shape_t).vec<int64>();
|
||||||
|
|
||||||
|
Tensor batch_ptr_t(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({batch_size + 1}));
|
||||||
|
auto batch_ptr = batch_ptr_t.vec<int32>();
|
||||||
|
auto indices = indices_t.matrix<int64>();
|
||||||
|
|
||||||
|
batch_ptr(0) = 0;
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
batch_ptr(i + 1) = batch_ptr(i) + nnz_per_batch(i);
|
||||||
|
}
|
||||||
|
int total_nnz = batch_ptr(batch_size);
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c, total_nnz == values_t.NumElements(),
|
||||||
|
errors::Internal("nnz returned by "
|
||||||
|
"CalculateNNZPerBatchMatrixFromInd"
|
||||||
|
"ices != len(values): ",
|
||||||
|
total_nnz, " vs. ", values_t.NumElements()),
|
||||||
|
done);
|
||||||
|
|
||||||
|
Tensor coo_col_ind_t;
|
||||||
|
Tensor csr_row_ptr_t;
|
||||||
|
Tensor csr_values_t = values_t;
|
||||||
|
|
||||||
|
Tensor coo_row_ind_t;
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({total_nnz}), &coo_row_ind_t),
|
||||||
|
done);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({total_nnz}), &coo_col_ind_t),
|
||||||
|
done);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({batch_size * (rows + 1)}),
|
||||||
|
&csr_row_ptr_t),
|
||||||
|
done);
|
||||||
|
|
||||||
|
auto coo_row_ind = coo_row_ind_t.vec<int32>();
|
||||||
|
auto coo_col_ind = coo_col_ind_t.vec<int32>();
|
||||||
|
auto csr_row_ptr = csr_row_ptr_t.vec<int32>();
|
||||||
|
|
||||||
|
// Convert SparseTensor rep to coo row ind, coo col ind.
|
||||||
|
if (total_nnz > 0) {
|
||||||
|
functor::SparseTensorToCOOSparseMatrix<Device> st_to_coo;
|
||||||
|
st_to_coo(d, dense_shape, indices, coo_row_ind, coo_col_ind);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set all csr row pointers to zero, so that when iterating over
|
||||||
|
// batches converting coo to csr, we do not have to perform an
|
||||||
|
// unaligned SetZero for any nnz == 0 minibatches. coo2csr has
|
||||||
|
// a bug if you have empty coo rows.
|
||||||
|
// TODO(ebrevdo): File bug w/ nvidia so coo2csr can handle
|
||||||
|
// zero-element input coo rows.
|
||||||
|
functor::SetZeroFunctor<Device, int32> set_zero;
|
||||||
|
set_zero(d, csr_row_ptr_t.flat<int32>());
|
||||||
|
|
||||||
|
functor::COOSparseMatrixToCSRSparseMatrix<Device> coo_to_csr;
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
int nnz_i = batch_ptr(i + 1) - batch_ptr(i);
|
||||||
|
if (nnz_i == 0) {
|
||||||
|
// This is an empty minibatch; no call to coo2csr: it's
|
||||||
|
// handled by the SetZero above.
|
||||||
|
} else {
|
||||||
|
// Convert coo to csr.
|
||||||
|
auto coo_row_ind_i =
|
||||||
|
TTypes<int32>::UnalignedVec(&coo_row_ind(batch_ptr(i)), nnz_i);
|
||||||
|
auto csr_row_ptr_i = TTypes<int32>::UnalignedVec(
|
||||||
|
&csr_row_ptr((rows + 1) * i), rows + 1);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c, coo_to_csr(c, rows, cols, coo_row_ind_i, csr_row_ptr_i), done);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CSRSparseMatrix matrix;
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
values_t.dtype(), dense_shape_t, batch_ptr_t, csr_row_ptr_t,
|
||||||
|
coo_col_ind_t, csr_values_t, &matrix),
|
||||||
|
done);
|
||||||
|
Tensor* matrix_t;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c, c->allocate_output(0, TensorShape({}), &matrix_t, cpu_alloc),
|
||||||
|
done);
|
||||||
|
matrix_t->scalar<Variant>()() = std::move(matrix);
|
||||||
|
|
||||||
|
done();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (rank == 2) {
|
||||||
|
convert_to_csr();
|
||||||
|
} else {
|
||||||
|
// Launch the GPU kernel to count nnz entries, then call convert_to_csr.
|
||||||
|
c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
|
||||||
|
stream, convert_to_csr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_GPU(DEV, T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("DenseToCSRSparseMatrix") \
|
||||||
|
.Device(DEVICE_##DEV) \
|
||||||
|
.TypeConstraint<T>("T"), \
|
||||||
|
DenseToCSRSparseMatrixGPUOp<DEV##Device, T>);
|
||||||
|
|
||||||
|
REGISTER_GPU(GPU, float)
|
||||||
|
REGISTER_GPU(GPU, double)
|
||||||
|
REGISTER_GPU(GPU, complex64)
|
||||||
|
REGISTER_GPU(GPU, complex128)
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||||
|
OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int32>::Vec nnz_per_batch);
|
||||||
|
extern template struct CalculateNNZPerBatchMatrixFromIndices<GPUDevice>;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct SparseTensorToCOOSparseMatrix<GPUDevice> {
|
||||||
|
void operator()(const GPUDevice& d, TTypes<int64>::ConstVec host_dense_shape,
|
||||||
|
TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int>::Vec coo_row_ind, TTypes<int>::Vec coo_col_ind);
|
||||||
|
};
|
||||||
|
extern template struct SparseTensorToCOOSparseMatrix<GPUDevice>;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct COOSparseMatrixToCSRSparseMatrix<GPUDevice> {
|
||||||
|
Status operator()(OpKernelContext* c, const int rows, const int cols,
|
||||||
|
TTypes<int>::UnalignedVec coo_row_ind,
|
||||||
|
TTypes<int>::UnalignedVec csr_row_ptr) {
|
||||||
|
CudaSparse cuda_sparse(c);
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||||
|
return cuda_sparse.Coo2csr(coo_row_ind.data(),
|
||||||
|
/*nnz*/ coo_row_ind.size(),
|
||||||
|
/*m == rows of A*/ rows, csr_row_ptr.data());
|
||||||
|
}
|
||||||
|
};
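// Illustrative note (added commentary, not part of the original code): for a
// single 4-row matrix with sorted coo_row_ind = [0, 0, 2, 3], Coo2csr yields
// csr_row_ptr = [0, 2, 2, 3, 4]; csr_row_ptr[r + 1] - csr_row_ptr[r] is the
// number of nonzeros in row r.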
|
||||||
|
extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#undef REGISTER_GPU
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
tensorflow/core/kernels/sparse/kernels.cc (new file, 100 lines)
@ -0,0 +1,100 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
Status SparseTensorToCSRSparseMatrixCPUFunctor::operator()(
|
||||||
|
const int64 batch_size, const int num_rows,
|
||||||
|
TTypes<int64>::ConstMatrix indices, TTypes<int32>::Vec batch_ptr,
|
||||||
|
TTypes<int32>::Vec csr_row_ptr, TTypes<int32>::Vec csr_col_ind) {
|
||||||
|
// Validate inputs.
|
||||||
|
if (batch_ptr.size() != batch_size + 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Expected batch_ptr.size() == batch_size + 1. Got: ", batch_ptr.size(),
|
||||||
|
" vs. ", batch_size + 1);
|
||||||
|
}
|
||||||
|
if (csr_row_ptr.size() != batch_size * (num_rows + 1)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Expected csr_row_ptr.size() == batch_size * (num_rows + 1). Got: ",
|
||||||
|
csr_row_ptr.size(), " vs. ", batch_size * (num_rows + 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64 total_nnz = indices.dimension(0);
|
||||||
|
const int rank = indices.dimension(1);
|
||||||
|
if (rank == 2 && batch_size != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Expected batch_size == 1 when rank is 2. Got batch_size: ",
|
||||||
|
batch_size);
|
||||||
|
}
|
||||||
|
if (csr_col_ind.size() != total_nnz) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Expected csr_col_ind.size() == total_nnz. Got: ", csr_col_ind.size(),
|
||||||
|
" vs. ", total_nnz);
|
||||||
|
}
|
||||||
|
|
||||||
|
int prev_batch = -1;
|
||||||
|
if (rank == 2) {
|
||||||
|
// For a single batch, the batch_ptrs are {0, total_nnz}.
|
||||||
|
batch_ptr(0) = 0;
|
||||||
|
++prev_batch;
|
||||||
|
|
||||||
|
for (int64 i = 0; i < total_nnz; ++i) {
|
||||||
|
// For now, the rows pointers store the corresponding row counts.
|
||||||
|
csr_row_ptr(indices(i, 0) + 1) += 1;
|
||||||
|
csr_col_ind(i) = indices(i, 1);
|
||||||
|
}
|
||||||
|
} else { // rank == 3
|
||||||
|
for (int64 i = 0; i < total_nnz; ++i) {
|
||||||
|
const int cur_batch = indices(i, 0);
|
||||||
|
// For now, the rows pointers store the corresponding row counts.
|
||||||
|
csr_row_ptr(cur_batch * (num_rows + 1) + indices(i, 1) + 1) += 1;
|
||||||
|
csr_col_ind(i) = indices(i, 2);
|
||||||
|
|
||||||
|
// We're at a new batch and might have skipped over empty batches.
|
||||||
|
while (prev_batch < cur_batch) {
|
||||||
|
// The previous batch ends at position i.
|
||||||
|
batch_ptr(prev_batch + 1) = i;
|
||||||
|
++prev_batch;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Set the last element of batch_ptr and account for trailing empty batches.
|
||||||
|
while (prev_batch < batch_size) {
|
||||||
|
batch_ptr(prev_batch + 1) = total_nnz;
|
||||||
|
++prev_batch;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the cumulative row counts for each batch.
|
||||||
|
for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||||
|
auto* row_ptr_batch = csr_row_ptr.data() + batch_idx * (num_rows + 1);
|
||||||
|
std::partial_sum(row_ptr_batch, row_ptr_batch + num_rows + 1,
|
||||||
|
row_ptr_batch);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
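// Worked example (added commentary, not part of the original code): with
// batch_size = 2, num_rows = 2, and indices = [[0, 0, 1], [0, 1, 0]] (both
// nonzeros in batch 0), the functor above fills
//   batch_ptr   = [0, 2, 2]           (batch 1 is empty)
//   csr_row_ptr = [0, 1, 2, 0, 0, 0]  (per-batch cumulative row counts)
//   csr_col_ind = [1, 0]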
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
tensorflow/core/kernels/sparse/kernels.h (new file, 247 lines)
@ -0,0 +1,247 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
// Calculates number of nonzero entries per batch of a sorted rank-3
|
||||||
|
// SparseTensor's indices. indices is expected to have columns
|
||||||
|
// corresponding to [batch, row, column], where indices[:,0] < B.
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
// indices.dimension(1) == 3
|
||||||
|
// nnz_per_batch.dimension(0) == B
|
||||||
|
template <typename Device>
|
||||||
|
struct CalculateNNZPerBatchMatrixFromIndices {
|
||||||
|
Status operator()(OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int32>::Vec nnz_per_batch);
|
||||||
|
};
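// Illustrative example (added commentary, not part of the original header):
// for indices = [[0, 1, 2], [0, 3, 0], [2, 0, 1]] and B = 3, the expected
// result is nnz_per_batch = [2, 0, 1].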
|
||||||
|
|
||||||
|
// Split a subset of a SparseTensor's indices into two vectors:
|
||||||
|
// COO row inds and COO col inds. Outputs are:
|
||||||
|
//
|
||||||
|
// coo_row_ind = indices[:, row_dim]
|
||||||
|
// coo_col_ind = indices[:, row_dim + 1]
|
||||||
|
//
|
||||||
|
// where n = coo_row_ind.size()
|
||||||
|
// and row_dim = #cols(indices) - 2
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
// host_dense_shape.size() in [2, 3]
|
||||||
|
// indices.dim_size(1) == host_dense_shape.size()
|
||||||
|
// coo_row_ind.size() == coo_col_ind.size()
|
||||||
|
// coo_row_ind.size() == indices.dim_size(0)
|
||||||
|
template <typename Device>
|
||||||
|
struct SparseTensorToCOOSparseMatrix {
|
||||||
|
void operator()(const Device& d, TTypes<int64>::ConstVec host_dense_shape,
|
||||||
|
TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int32>::Vec coo_row_ind,
|
||||||
|
TTypes<int32>::Vec coo_col_ind);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Write coo batch, row, and column vectors to output matrix indices:
|
||||||
|
//
|
||||||
|
// indices[:, row_dim] = coo_row_ind
|
||||||
|
// indices[:, col_dim] = coo_col_ind
|
||||||
|
//
|
||||||
|
// where row_dim = #cols(indices) - 2, col_dim = #cols(indices) - 1, and
// n = coo_row_ind.size().
|
||||||
|
// In addition, if #cols(indices) == 3, also store the batch:
|
||||||
|
//
|
||||||
|
// indices[i, 0] = batch_of(i) where
|
||||||
|
// host_batch_ptrs(batch_of(i)) <= i < host_batch_ptrs(batch_of(i) + 1)
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
//
|
||||||
|
// host_dense_shape.size() in [2, 3]
|
||||||
|
// indices.dim_size(1) == host_dense_shape.size()
|
||||||
|
// host_batch_ptr.size() ==
|
||||||
|
// coo_row_ind.size() == coo_col_ind.size()
|
||||||
|
//
|
||||||
|
template <typename Device>
|
||||||
|
struct COOSparseMatrixToSparseTensor {
|
||||||
|
Status operator()(OpKernelContext* c,
|
||||||
|
TTypes<int64>::ConstVec host_dense_shape,
|
||||||
|
TTypes<int32>::ConstVec host_batch_ptrs,
|
||||||
|
TTypes<int32>::Vec coo_row_ind,
|
||||||
|
TTypes<int32>::ConstVec coo_col_ind,
|
||||||
|
TTypes<int64>::Matrix indices);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert a vector of coo row indices to csr row pointers.
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
//
|
||||||
|
// csr_row_ptr.size() == rows + 1.
|
||||||
|
// max(coo_row_ptr) < rows.
|
||||||
|
//
|
||||||
|
template <typename Device>
|
||||||
|
struct COOSparseMatrixToCSRSparseMatrix {
|
||||||
|
Status operator()(OpKernelContext* c, const int rows, const int cols,
|
||||||
|
TTypes<int32>::UnalignedVec coo_row_ind,
|
||||||
|
TTypes<int32>::UnalignedVec csr_row_ptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert a matrix of (batched) coo row and column indices to CSR SparseMatrix
|
||||||
|
// batch ptrs, csr row pointers and coo column indices.
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
// batch_ptr.size() == batch_size + 1
|
||||||
|
// csr_row_ptr.size() == batch_size * (num_rows + 1)
|
||||||
|
// csr_col_ind.size() == total_nnz
|
||||||
|
// batch_size == 1 if rank == 2
|
||||||
|
//
|
||||||
|
// where
|
||||||
|
// total_nnz = indices.dim_size(0)
|
||||||
|
// rank = indices.dim_size(1)
|
||||||
|
// Also csr_row_ptr should be initially filled with zeros.
|
||||||
|
//
|
||||||
|
struct SparseTensorToCSRSparseMatrixCPUFunctor {
|
||||||
|
Status operator()(const int64 batch_size, const int num_rows,
|
||||||
|
TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int32>::Vec batch_ptr,
|
||||||
|
TTypes<int32>::Vec csr_row_ptr,
|
||||||
|
TTypes<int32>::Vec csr_col_ind);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert a vector of csr row pointers to coo row indices.
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
//
|
||||||
|
// coo_row_ind.size() == nnz.
|
||||||
|
// csr_row_ptr[-1] == nnz.
|
||||||
|
//
|
||||||
|
template <typename Device>
|
||||||
|
struct CSRSparseMatrixToCOOSparseMatrix {
|
||||||
|
Status operator()(OpKernelContext* c,
|
||||||
|
TTypes<int32>::UnalignedConstVec csr_row_ptr,
|
||||||
|
TTypes<int32>::UnalignedVec coo_row_ind);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format
|
||||||
|
// and B and C are dense.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseMatrixMatMul {
|
||||||
|
explicit CSRSparseMatrixMatMul(const bool transpose_output);
|
||||||
|
Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
|
||||||
|
typename TTypes<T>::ConstMatrix b,
|
||||||
|
typename TTypes<T>::Matrix c);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format
|
||||||
|
// and x and y are dense vectors.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class CSRSparseMatrixMatVec {
|
||||||
|
CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a);
|
||||||
|
Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
|
||||||
|
const T* x, T* y);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates C = functor(A, B) where A and B are CSR and C is CSR
|
||||||
|
// with a different sparsity pattern.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRStructureModifyingFunctor {
|
||||||
|
virtual ~CSRStructureModifyingFunctor() {}
|
||||||
|
|
||||||
|
virtual Status Initialize() = 0;
|
||||||
|
|
||||||
|
virtual Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
||||||
|
const ConstCSRComponent<T>& b,
|
||||||
|
TTypes<int32>::UnalignedVec c_row_ptr,
|
||||||
|
int* output_nnz) = 0;
|
||||||
|
|
||||||
|
virtual Status Compute(const ConstCSRComponent<T>& a,
|
||||||
|
const ConstCSRComponent<T>& b, CSRComponent<T>* c) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates C = alpha * A + beta * B, where A and B are in CSR
|
||||||
|
// format, and alpha and beta are scalars on the host.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseMatrixAdd : public CSRStructureModifyingFunctor<Device, T> {
|
||||||
|
explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha,
|
||||||
|
const T beta);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates C = matmul(A, B), where A, B, and C are in CSR format.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseSparseMatrixMatMul
|
||||||
|
: public CSRStructureModifyingFunctor<Device, T> {
|
||||||
|
explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
|
||||||
|
bool transpose_b);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates Y = transpose(X) where X and Y are CSR format components.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseMatrixTransposeComponent {
|
||||||
|
Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
|
||||||
|
CSRComponent<T>* y);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates Y = transpose(X) where X and Y are in CSR format.
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseMatrixTranspose {
|
||||||
|
Status operator()(OpKernelContext* ctx, bool conjugate,
|
||||||
|
const CSRSparseMatrix& input_matrix,
|
||||||
|
CSRSparseMatrix* output_matrix);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculates Y = softmax(X) where X and Y are in CSR format;
|
||||||
|
// missing coefficients in X are treated as -inf (logits of 0 probability).
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseMatrixSoftmax {
|
||||||
|
Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits,
|
||||||
|
typename TTypes<T>::Vec softmax_values);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct CSRSparseMatrixSoftmaxGrad {
|
||||||
|
Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax,
|
||||||
|
const CSRSparseMatrix& grad_softmax,
|
||||||
|
typename TTypes<T>::Vec gradient_values);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class CSRSparseMatrixMulScalar {
|
||||||
|
public:
|
||||||
|
explicit CSRSparseMatrixMulScalar() {}
|
||||||
|
|
||||||
|
Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
|
||||||
|
typename TTypes<T>::ConstScalar b, CSRSparseMatrix* c);
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class CSRSparseMatrixBatchMulVec {
|
||||||
|
public:
|
||||||
|
explicit CSRSparseMatrixBatchMulVec() {}
|
||||||
|
|
||||||
|
Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
|
||||||
|
typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_
|
tensorflow/core/kernels/sparse/kernels_gpu.cu.cc (new file, 676 lines)
@ -0,0 +1,676 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "third_party/cub/device/device_histogram.cuh"
|
||||||
|
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||||
|
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||||
|
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/gpu_device_array.h"
|
||||||
|
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct StridedDataReader {
|
||||||
|
StridedDataReader(const int64* begin, int stride)
|
||||||
|
: begin_(begin), stride_(stride) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
|
||||||
|
return static_cast<int>(ldg(begin_ + idx * stride_));
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64* begin_;
|
||||||
|
const int stride_;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||||
|
OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int32>::Vec nnz_per_batch) {
|
||||||
|
const auto& cu_stream = GetGpuStream(c);
|
||||||
|
|
||||||
|
const int total_nnz = indices.dimension(0);
|
||||||
|
const int size = nnz_per_batch.size();
|
||||||
|
|
||||||
|
DCHECK_EQ(indices.rank(), 2);
|
||||||
|
DCHECK_EQ(indices.dimension(1), 3); // batch, row, col
|
||||||
|
|
||||||
|
const int rank = indices.dimension(1);
|
||||||
|
cub::CountingInputIterator<int> row_counter(0);
|
||||||
|
cub::TransformInputIterator<int, StridedDataReader,
|
||||||
|
cub::CountingInputIterator<int>>
|
||||||
|
indices_first_column(row_counter,
|
||||||
|
StridedDataReader(indices.data(), rank));
|
||||||
|
|
||||||
|
std::size_t temp_storage_bytes = 0;
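// Note (added commentary): cub::DeviceHistogram::HistogramEven is invoked
// twice below, following cub's usual two-pass pattern: the first call passes a
// null d_temp_storage pointer and only computes temp_storage_bytes; the second
// call performs the actual histogram over the batch column of the indices.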
|
||||||
|
|
||||||
|
DCHECK_NE(indices.data(), nullptr);
|
||||||
|
DCHECK_NE(nnz_per_batch.data(), nullptr);
|
||||||
|
|
||||||
|
auto first_success = cub::DeviceHistogram::HistogramEven(
|
||||||
|
/*d_temp_storage*/ nullptr,
|
||||||
|
/*temp_storage_bytes&*/ temp_storage_bytes,
|
||||||
|
/*d_samples*/ indices_first_column,
|
||||||
|
/*d_histogram*/ nnz_per_batch.data(),
|
||||||
|
/*num_levels*/ size + 1,
|
||||||
|
/*lower_level*/ 0,
|
||||||
|
/*upper_level*/ size,
|
||||||
|
/*num_samples*/ total_nnz,
|
||||||
|
/*stream*/ cu_stream);
|
||||||
|
|
||||||
|
if (first_success != cudaSuccess) {
|
||||||
|
return errors::Internal(
|
||||||
|
"SparseTensorToCSRSparseMatrix: Could not launch "
|
||||||
|
"cub::DeviceHistogram::HistogramEven "
|
||||||
|
"to calculate temp_storage_bytes, status: ",
|
||||||
|
cudaGetErrorString(first_success));
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor temp_storage;
|
||||||
|
TF_RETURN_IF_ERROR(c->allocate_temp(
|
||||||
|
DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
|
||||||
|
&temp_storage));
|
||||||
|
DCHECK_NE(temp_storage.flat<int8>().data(), nullptr);
|
||||||
|
auto second_success = cub::DeviceHistogram::HistogramEven(
|
||||||
|
/*d_temp_storage*/ temp_storage.flat<int8>().data(),
|
||||||
|
/*temp_storage_bytes&*/ temp_storage_bytes,
|
||||||
|
/*d_samples*/ indices_first_column,
|
||||||
|
/*d_histogram*/ nnz_per_batch.data(),
|
||||||
|
/*num_levels*/ size + 1,
|
||||||
|
/*lower_level*/ 0,
|
||||||
|
/*upper_level*/ size,
|
||||||
|
/*num_samples*/ total_nnz,
|
||||||
|
/*stream*/ cu_stream);
|
||||||
|
|
||||||
|
if (second_success != cudaSuccess) {
|
||||||
|
return errors::Internal(
|
||||||
|
"SparseTensorToCSRSparseMatrix: Could not launch "
|
||||||
|
"cub::DeviceHistogram::HistogramEven "
|
||||||
|
"to count nnz entries per batch. temp_storage_bytes: ",
|
||||||
|
temp_storage_bytes, ", status: ", cudaGetErrorString(second_success));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int stride>
|
||||||
|
__global__ void SparseTensorToCOOMatrixKernel(const int64* indices,
|
||||||
|
int* coo_rows_out,
|
||||||
|
int* coo_cols_out, int size) {
|
||||||
|
const int offset = (stride == 3) ? 1 : 0;
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||||
|
coo_rows_out[i] = static_cast<int>(ldg(indices + i * stride + offset));
|
||||||
|
coo_cols_out[i] = static_cast<int>(ldg(indices + i * stride + offset + 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
void SparseTensorToCOOSparseMatrix<GPUDevice>::operator()(
|
||||||
|
const GPUDevice& d, TTypes<int64>::ConstVec host_dense_shape,
|
||||||
|
TTypes<int64>::ConstMatrix indices, TTypes<int>::Vec coo_row_ind,
|
||||||
|
TTypes<int>::Vec coo_col_ind) {
|
||||||
|
const int stride = host_dense_shape.size();
|
||||||
|
DCHECK(stride == 2 || stride == 3);
|
||||||
|
DCHECK_EQ(stride, indices.dimension(1));
|
||||||
|
const int size = coo_row_ind.dimension(0);
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||||
|
if (stride == 2) {
|
||||||
|
SparseTensorToCOOMatrixKernel<2>
|
||||||
|
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
|
||||||
|
} else {
|
||||||
|
SparseTensorToCOOMatrixKernel<3>
|
||||||
|
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
indices.data(), coo_row_ind.data(), coo_col_ind.data(), size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void COOMatrixToSparseTensorKernel2D(const int* coo_rows,
|
||||||
|
const int* coo_cols,
|
||||||
|
int64* indices_out, int size) {
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||||
|
indices_out[i * 2] = static_cast<int64>(ldg(coo_rows + i));
|
||||||
|
indices_out[i * 2 + 1] = static_cast<int64>(ldg(coo_cols + i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ inline int BinarySearchRange(int* range, int n, int x) {
|
||||||
|
int left = 0;
|
||||||
|
int right = n - 1;
|
||||||
|
while (left < right) {
|
||||||
|
int mid = left + (right - left) / 2;
|
||||||
|
if (x < range[mid])
|
||||||
|
right = mid - 1;
|
||||||
|
else if (range[mid + 1] <= x)
|
||||||
|
left = mid + 1;
|
||||||
|
else
|
||||||
|
return mid; // range[mid] <= x < range[mid + 1].
|
||||||
|
}
|
||||||
|
return left;
|
||||||
|
}
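// Illustrative example (added commentary, not part of the original code): with
// range = [0, 2, 2, 5] and n = 3 batches, BinarySearchRange returns 0 for
// x in {0, 1}, skips the empty batch 1, and returns 2 for x in {2, 3, 4}.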
|
||||||
|
|
||||||
|
__global__ void COOMatrixToSparseTensorKernel3D(
|
||||||
|
const int* coo_rows, const int* coo_cols, int64* indices_out,
|
||||||
|
GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
|
||||||
|
const int size) {
|
||||||
|
// Step 1: access the batch ptrs and copy to shared memory.
|
||||||
|
const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
|
||||||
|
extern __shared__ int local_batch_ptr[];
|
||||||
|
for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
|
||||||
|
local_batch_ptr[i] = batch_ptr[i];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||||
|
// TODO(ebrevdo): Consider special casing batch_size <= 3,
|
||||||
|
// alternatively doing linear instead of binary search. Requires
|
||||||
|
// some benchmarks.
|
||||||
|
const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
|
||||||
|
indices_out[i * 3] = static_cast<int64>(b);
|
||||||
|
indices_out[i * 3 + 1] = static_cast<int64>(ldg(coo_rows + i));
|
||||||
|
indices_out[i * 3 + 2] = static_cast<int64>(ldg(coo_cols + i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Status COOSparseMatrixToSparseTensor<GPUDevice>::operator()(
|
||||||
|
OpKernelContext* c, TTypes<int64>::ConstVec host_dense_shape,
|
||||||
|
TTypes<int>::ConstVec host_batch_ptr, TTypes<int>::Vec coo_row_ind,
|
||||||
|
TTypes<int>::ConstVec coo_col_ind, TTypes<int64>::Matrix indices) {
|
||||||
|
const int ndims = indices.dimension(1);
|
||||||
|
DCHECK(ndims == 2 || ndims == 3);
|
||||||
|
DCHECK_EQ(ndims, host_dense_shape.size());
|
||||||
|
DCHECK_NE(coo_row_ind.data(), nullptr);
|
||||||
|
DCHECK_NE(coo_col_ind.data(), nullptr);
|
||||||
|
DCHECK_NE(indices.data(), nullptr);
|
||||||
|
const GPUDevice& d = c->eigen_device<GPUDevice>();
|
||||||
|
const int size = coo_row_ind.size();
|
||||||
|
DCHECK_EQ(size, coo_col_ind.size());
|
||||||
|
DCHECK_EQ(size, indices.dimension(0));
|
||||||
|
if (ndims == 2) {
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||||
|
COOMatrixToSparseTensorKernel2D<<<config.block_count,
|
||||||
|
config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
coo_row_ind.data(), coo_col_ind.data(), indices.data(), size);
|
||||||
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
const int batch_size = host_dense_shape(0);
|
||||||
|
GpuDeviceArrayOnHost<int> batch_ptr_copy(c, host_batch_ptr.size());
|
||||||
|
TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
batch_ptr_copy.Set(i, host_batch_ptr(i));
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||||
|
// shared memory stores the batch pointers.
|
||||||
|
const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
|
||||||
|
COOMatrixToSparseTensorKernel3D<<<config.block_count,
|
||||||
|
config.thread_per_block,
|
||||||
|
shared_memory_size, d.stream()>>>(
|
||||||
|
coo_row_ind.data(), coo_col_ind.data(), indices.data(),
|
||||||
|
batch_ptr_copy.data(), batch_size, size);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
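// Illustrative example (added commentary, not part of the original code): with
// host_batch_ptr = [0, 2, 3], coo_row_ind = [0, 1, 0] and coo_col_ind =
// [2, 0, 1], the 3D kernel writes indices = [[0, 0, 2], [0, 1, 0], [1, 0, 1]].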
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void CSRSparseMatrixBatchMulVecKernel3D(
|
||||||
|
const T* a_values, const T* b_batch_values, T* c_values,
|
||||||
|
GpuDeviceArrayStruct<int> batch_ptr_s, const int batch_size,
|
||||||
|
const int total_nnz) {
|
||||||
|
// Step 1: Access the batch ptrs and copy to shared memory.
|
||||||
|
// Also copy the per-batch multipliers into shared memory.
|
||||||
|
const int* batch_ptr = GetGpuDeviceArrayOnDevice(&batch_ptr_s);
|
||||||
|
extern __shared__ int local_batch_ptr[];
|
||||||
|
T* local_batch_values =
|
||||||
|
reinterpret_cast<T*>(local_batch_ptr + batch_size + 1);
|
||||||
|
for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) {
|
||||||
|
local_batch_ptr[i] = batch_ptr[i];
|
||||||
|
if (i < batch_size) {
|
||||||
|
local_batch_values[i] = b_batch_values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, total_nnz) {
|
||||||
|
const int b = BinarySearchRange(local_batch_ptr, batch_size, i);
|
||||||
|
c_values[i] = ldg(a_values + i) * local_batch_values[b];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status CSRSparseMatrixBatchMulVecImpl(OpKernelContext* ctx,
|
||||||
|
const CSRSparseMatrix& a,
|
||||||
|
typename TTypes<T>::ConstFlat b,
|
||||||
|
CSRSparseMatrix* c) {
|
||||||
|
DCHECK_EQ(a.dims(), 3);
|
||||||
|
const int total_nnz = a.total_nnz();
|
||||||
|
Tensor c_values_t;
|
||||||
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value,
|
||||||
|
TensorShape({total_nnz}), &c_values_t));
|
||||||
|
TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
|
||||||
|
a.row_pointers(), a.col_indices(), c_values_t, c));
|
||||||
|
|
||||||
|
auto a_values = a.values().flat<T>();
|
||||||
|
auto c_values = c_values_t.flat<T>();
|
||||||
|
|
||||||
|
auto host_dense_shape = a.dense_shape().vec<int64>();
|
||||||
|
auto host_batch_ptr = a.batch_pointers().vec<int>();
|
||||||
|
|
||||||
|
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
||||||
|
|
||||||
|
const int batch_size = host_dense_shape(0);
|
||||||
|
DCHECK_EQ(b.size(), batch_size);
|
||||||
|
|
||||||
|
GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
|
||||||
|
TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
batch_ptr_copy.Set(i, host_batch_ptr(i));
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(total_nnz, d);
|
||||||
|
// shared memory stores the batch pointers.
|
||||||
|
const size_t shared_memory_size =
|
||||||
|
(sizeof(int) * (batch_size + 1) // local batch_pointers.
|
||||||
|
+ sizeof(T) * batch_size); // local copy of b.
|
||||||
|
CSRSparseMatrixBatchMulVecKernel3D<T>
|
||||||
|
<<<config.block_count, config.thread_per_block, shared_memory_size,
|
||||||
|
d.stream()>>>(a_values.data(), b.data(), c_values.data(),
|
||||||
|
batch_ptr_copy.data(), batch_size, total_nnz);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DEFINE_SPARSE_MUL_VEC_GPU(T) \
|
||||||
|
template <> \
|
||||||
|
CSRSparseMatrixBatchMulVec<GPUDevice, T>::CSRSparseMatrixBatchMulVec() {} \
|
||||||
|
template <> \
|
||||||
|
Status CSRSparseMatrixBatchMulVec<GPUDevice, T>::Compute( \
|
||||||
|
OpKernelContext* ctx, const CSRSparseMatrix& a, \
|
||||||
|
typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c) { \
|
||||||
|
return CSRSparseMatrixBatchMulVecImpl<T>(ctx, a, b, c); \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_SPARSE_MUL_VEC_GPU(float);
|
||||||
|
DEFINE_SPARSE_MUL_VEC_GPU(double);
|
||||||
|
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<float>);
|
||||||
|
DEFINE_SPARSE_MUL_VEC_GPU(std::complex<double>);
|
||||||
|
|
||||||
|
#undef DEFINE_SPARSE_MUL_VEC_GPU
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmax(const int begin,
|
||||||
|
const int end,
|
||||||
|
const T* logits,
|
||||||
|
T* softmax) {
|
||||||
|
// For each row, calculate the vector:
|
||||||
|
// softmax[row] = exp(shifted_logits[row]) / sum(exp(shifted_logits[row]))
|
||||||
|
// where
|
||||||
|
// shifted_logits[row] = logits[row] - max(logits[row])
|
||||||
|
// are the logits normalized for stability.
|
||||||
|
T row_max = Eigen::NumTraits<T>::lowest();
|
||||||
|
for (int r_i = begin; r_i < end; ++r_i) {
|
||||||
|
row_max = Eigen::numext::maxi(row_max, ldg(logits + r_i));
|
||||||
|
}
|
||||||
|
T sum_exp = 0;
|
||||||
|
for (int r_i = begin; r_i < end; ++r_i) {
|
||||||
|
const T exp_i = Eigen::numext::exp(ldg(logits + r_i) - row_max);
|
||||||
|
softmax[r_i] = exp_i;
|
||||||
|
sum_exp += exp_i;
|
||||||
|
}
|
||||||
|
for (int r_i = begin; r_i < end; ++r_i) {
|
||||||
|
softmax[r_i] = softmax[r_i] / sum_exp;
|
||||||
|
}
|
||||||
|
}
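// Illustrative example (added commentary, not part of the original code): a
// row whose stored logits are [1, 1, 1] yields softmax values [1/3, 1/3, 1/3];
// entries absent from the CSR structure do not enter the sum, matching the
// "missing coefficients are treated as -inf" convention in kernels.h.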
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void CSRSparseMatrixSoftmaxKernel2D(const int rows,
|
||||||
|
const int* row_ptr,
|
||||||
|
const T* logits, T* softmax) {
|
||||||
|
// TODO(ebrevdo): consider something like a merge-path based
|
||||||
|
// algorithm to distribute the work in case the row sizes are
|
||||||
|
// uneven:
|
||||||
|
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
|
||||||
|
CUDA_1D_KERNEL_LOOP(row, rows) {
|
||||||
|
CalculateRowSoftmax(ldg(row_ptr + row), ldg(row_ptr + row + 1), logits,
|
||||||
|
softmax);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void CopyFromGpuDeviceArrayToLocal(
|
||||||
|
GpuDeviceArrayStruct<int> cuda_ptr_s, int* local_ptr, int length) {
|
||||||
|
#ifdef __CUDA_ARCH__
|
||||||
|
const int* cuda_ptr = GetGpuDeviceArrayOnDevice(&cuda_ptr_s);
|
||||||
|
for (int i = threadIdx.x; i < length; i += blockDim.x) {
|
||||||
|
local_ptr[i] = cuda_ptr[i];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void CSRSparseMatrixSoftmaxKernel3D(
|
||||||
|
const int size, const int rows, GpuDeviceArrayStruct<int> batch_ptr_s,
|
||||||
|
const int* row_ptr, const T* logits, T* softmax) {
|
||||||
|
// TODO(ebrevdo): consider something like a merge-path based
|
||||||
|
// algorithm to distribute the work in case the row sizes are
|
||||||
|
// uneven:
|
||||||
|
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
|
||||||
|
const int batch_size = size / rows;
|
||||||
|
extern __shared__ int local_batch_ptr[];
|
||||||
|
CopyFromGpuDeviceArrayToLocal(std::move(batch_ptr_s), local_batch_ptr,
|
||||||
|
batch_size + 1);
|
||||||
|
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||||
|
const int batch = i / rows;
|
||||||
|
const int row = i % rows;
|
||||||
|
const int batch_offset = local_batch_ptr[batch];
|
||||||
|
const int row_offset = batch * (rows + 1) + row;
|
||||||
|
CalculateRowSoftmax(batch_offset + ldg(row_ptr + row_offset),
|
||||||
|
batch_offset + ldg(row_ptr + row_offset + 1), logits,
|
||||||
|
softmax);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status CSRSparseMatrixSoftmaxGPUImpl(OpKernelContext* ctx,
|
||||||
|
const CSRSparseMatrix& logits,
|
||||||
|
typename TTypes<T>::Vec softmax_values) {
|
||||||
|
auto host_dense_shape = logits.dense_shape().vec<int64>();
|
||||||
|
auto host_batch_ptr = logits.batch_pointers().vec<int32>();
|
||||||
|
auto row_ptr = logits.row_pointers().vec<int32>();
|
||||||
|
auto logits_values = logits.values().vec<T>();
|
||||||
|
|
||||||
|
const int ndims = host_dense_shape.size();
|
||||||
|
DCHECK(ndims == 2 || ndims == 3);
|
||||||
|
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
||||||
|
if (ndims == 2) {
|
||||||
|
const int rows = host_dense_shape(0);
|
||||||
|
DCHECK_EQ(rows, row_ptr.size() - 1);
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
|
||||||
|
CSRSparseMatrixSoftmaxKernel2D<T>
|
||||||
|
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
rows /*size*/, row_ptr.data(), logits_values.data(),
|
||||||
|
softmax_values.data());
|
||||||
|
} else {
|
||||||
|
const int batch_size = host_dense_shape(0);
|
||||||
|
const int rows = host_dense_shape(1);
|
||||||
|
DCHECK_EQ(batch_size, host_batch_ptr.size() - 1);
|
||||||
|
DCHECK_EQ((rows + 1) * batch_size, row_ptr.size());
|
||||||
|
const int size = rows * batch_size;
|
||||||
|
|
||||||
|
GpuDeviceArrayOnHost<int> batch_ptr_copy(ctx, host_batch_ptr.size());
|
||||||
|
TF_RETURN_IF_ERROR(batch_ptr_copy.Init());
|
||||||
|
for (int i = 0; i < host_batch_ptr.size(); ++i) {
|
||||||
|
batch_ptr_copy.Set(i, host_batch_ptr(i));
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(batch_ptr_copy.Finalize());
|
||||||
|
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||||
|
// shared memory stores the batch pointers.
|
||||||
|
const size_t shared_memory_size = sizeof(int) * (batch_size + 1);
|
||||||
|
CSRSparseMatrixSoftmaxKernel3D<T>
|
||||||
|
<<<config.block_count, config.thread_per_block, shared_memory_size,
|
||||||
|
d.stream()>>>(size, rows, batch_ptr_copy.data(), row_ptr.data(),
|
||||||
|
logits_values.data(), softmax_values.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DEFINE_SOFTMAX_GPU(T) \
|
||||||
|
template <> \
|
||||||
|
Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()( \
|
||||||
|
OpKernelContext* ctx, const CSRSparseMatrix& logits, \
|
||||||
|
typename TTypes<T>::Vec softmax_values) { \
|
||||||
|
return CSRSparseMatrixSoftmaxGPUImpl<T>(ctx, logits, softmax_values); \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_SOFTMAX_GPU(float);
|
||||||
|
DEFINE_SOFTMAX_GPU(double);
|
||||||
|
|
||||||
|
#undef DEFINE_SOFTMAX_GPU
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void CalculateRowSoftmaxGrad(
|
||||||
|
const int softmax_begin, const int softmax_end, const int* softmax_col_ind,
|
||||||
|
const T* softmax, const int grad_softmax_begin, const int grad_softmax_end,
|
||||||
|
const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
|
||||||
|
// Iterate from
|
||||||
|
// softmax_col_ind[softmax_begin] to
|
||||||
|
// softmax_col_ind[softmax_end]
|
||||||
|
// and from
|
||||||
|
// grad_softmax_col_ind[grad_softmax_begin] to
|
||||||
|
// grad_softmax_col_ind[grad_softmax_end]
|
||||||
|
//
|
||||||
|
// looking for matching indices. In the softmax indices only, perform:
|
||||||
|
//
|
||||||
|
// gradient = (grad_softmax - sum(grad_softmax * softmax)) * softmax
|
||||||
|
//
|
||||||
|
// where the sum is along the given row.
|
||||||
|
T sum_prod = 0;
|
||||||
|
for (int i = softmax_begin, j = grad_softmax_begin;
|
||||||
|
i < softmax_end && j < grad_softmax_end;) {
|
||||||
|
const int softmax_col = ldg(softmax_col_ind + i);
|
||||||
|
const int grad_softmax_col = ldg(grad_softmax_col_ind + j);
|
||||||
|
if (softmax_col == grad_softmax_col) {
|
||||||
|
sum_prod += ldg(softmax + i) * ldg(grad_softmax + j);
|
||||||
|
++i;
|
||||||
|
++j;
|
||||||
|
} else if (softmax_col > grad_softmax_col) {
|
||||||
|
++j;
|
||||||
|
} else {
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find an upper bound on the column numbers in this row; for use in
|
||||||
|
// the special case of an empty grad_softmax row and a non-empty
|
||||||
|
// softmax row.
|
||||||
|
const int softmax_col_upper_bound =
|
||||||
|
(softmax_begin == softmax_end)
|
||||||
|
? -1
|
||||||
|
: ldg(softmax_col_ind + softmax_end - 1) + 1;
|
||||||
|
for (int i = softmax_begin, j = grad_softmax_begin; i < softmax_end;) {
|
||||||
|
const int softmax_col = ldg(softmax_col_ind + i);
|
||||||
|
// We need to keep a large grad_softmax_col value if we're at the
|
||||||
|
// end of the grad_softmax row, so we can fill in the remainder of
|
||||||
|
// the gradients row (the last if branch in this loop).
|
||||||
|
const int grad_softmax_col = (j == grad_softmax_end)
|
||||||
|
? softmax_col_upper_bound
|
||||||
|
: ldg(grad_softmax_col_ind + j);
|
||||||
|
|
||||||
|
if (softmax_col == grad_softmax_col) {
|
||||||
|
gradient[i] = (ldg(grad_softmax + j) - sum_prod) * ldg(softmax + i);
|
||||||
|
++i;
|
||||||
|
++j;
|
||||||
|
} else if (softmax_col > grad_softmax_col) {
|
||||||
|
// grad_softmax is nonzero here, but since softmax is zero, the
|
||||||
|
// gradient is 0; so we skip it since the sparsity structure
|
||||||
|
// already encodes this zero.
|
||||||
|
++j;
|
||||||
|
} else {
|
||||||
|
// grad_softmax is zero but softmax is not.
|
||||||
|
gradient[i] = -sum_prod * ldg(softmax + i);
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
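// Illustrative check (added commentary, not part of the original code): if
// softmax and grad_softmax share a sparsity pattern with values s = [0.5, 0.5]
// and g = [1, 0], then sum_prod = 0.5 and the gradient row is
// [(1 - 0.5) * 0.5, (0 - 0.5) * 0.5] = [0.25, -0.25].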
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void CSRSparseMatrixSoftmaxGradKernel2D(
|
||||||
|
const int rows, const int* softmax_row_ptr, const int* softmax_col_ind,
|
||||||
|
const T* softmax, const int* grad_softmax_row_ptr,
|
||||||
|
const int* grad_softmax_col_ind, const T* grad_softmax, T* gradient) {
|
||||||
|
// TODO(ebrevdo): consider something like a merge-path based
|
||||||
|
// algorithm to distribute the work in case the row sizes are
|
||||||
|
// uneven:
|
||||||
|
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
|
||||||
|
CUDA_1D_KERNEL_LOOP(row, rows) {
|
||||||
|
CalculateRowSoftmaxGrad(
|
||||||
|
ldg(softmax_row_ptr + row) /*softmax_begin*/,
|
||||||
|
ldg(softmax_row_ptr + row + 1) /*softmax_end*/, softmax_col_ind,
|
||||||
|
softmax, ldg(grad_softmax_row_ptr + row) /*grad_softmax_begin*/,
|
||||||
|
ldg(grad_softmax_row_ptr + row + 1) /*grad_softmax_end*/,
|
||||||
|
grad_softmax_col_ind, grad_softmax, gradient);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void CSRSparseMatrixSoftmaxGradKernel3D(
|
||||||
|
const int size, const int rows,
|
||||||
|
GpuDeviceArrayStruct<int> softmax_and_grad_batch_ptr_s,
|
||||||
|
const int* softmax_row_ptr, const int* softmax_col_ind, const T* softmax,
|
||||||
|
const int* grad_softmax_row_ptr, const int* grad_softmax_col_ind,
|
||||||
|
const T* grad_softmax, T* gradient) {
|
||||||
|
// TODO(ebrevdo): consider something like a merge-path based
|
||||||
|
// algorithm to distribute the work in case the row sizes are
|
||||||
|
// uneven:
|
||||||
|
// http://images.nvidia.com/events/sc15/pdfs/sc15-Merge-Based-Parallel-Sparse-Matrix-Vector-Multiplication-merrill.pdf
|
||||||
|
|
||||||
|
const int batch_size = size / rows;
|
||||||
|
extern __shared__ int local_batch_ptr[];
|
||||||
|
CopyFromGpuDeviceArrayToLocal(std::move(softmax_and_grad_batch_ptr_s),
|
||||||
|
local_batch_ptr, 2 * (batch_size + 1));
|
||||||
|
|
||||||
|
#define SOFTMAX_BATCH_PTR(i) local_batch_ptr[i];
|
||||||
|
#define GRAD_SOFTMAX_BATCH_PTR(i) local_batch_ptr[batch_size + 1 + i];
|
||||||
|
|
||||||
|
CUDA_1D_KERNEL_LOOP(i, size) {
|
||||||
|
const int batch = i / rows;
|
||||||
|
const int row = i % rows;
|
||||||
|
const int softmax_batch_offset = SOFTMAX_BATCH_PTR(batch);
|
||||||
|
const int grad_softmax_batch_offset = GRAD_SOFTMAX_BATCH_PTR(batch);
|
||||||
|
const int row_offset = batch * (rows + 1) + row;
|
||||||
|
CalculateRowSoftmaxGrad(
|
||||||
|
softmax_batch_offset +
|
||||||
|
ldg(softmax_row_ptr + row_offset) /*softmax_begin*/,
|
||||||
|
softmax_batch_offset +
|
||||||
|
ldg(softmax_row_ptr + row_offset + 1) /*softmax_end*/,
|
||||||
|
softmax_col_ind, softmax,
|
||||||
|
grad_softmax_batch_offset +
|
||||||
|
ldg(grad_softmax_row_ptr + row_offset) /*grad_softmax_begin*/,
|
||||||
|
grad_softmax_batch_offset +
|
||||||
|
ldg(grad_softmax_row_ptr + row_offset + 1) /*grad_softmax_end*/,
|
||||||
|
grad_softmax_col_ind, grad_softmax, gradient);
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef SOFTMAX_BATCH_PTR
|
||||||
|
#undef GRAD_SOFTMAX_BATCH_PTR
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status CSRSparseMatrixSoftmaxGradGPUImpl(
|
||||||
|
OpKernelContext* ctx, const CSRSparseMatrix& softmax,
|
||||||
|
const CSRSparseMatrix& grad_softmax,
|
||||||
|
typename TTypes<T>::Vec gradient_values) {
|
||||||
|
auto host_dense_shape = softmax.dense_shape().vec<int64>();
|
||||||
|
auto softmax_host_batch_ptr = softmax.batch_pointers().vec<int32>();
|
||||||
|
auto softmax_row_ptr = softmax.row_pointers().vec<int32>();
|
||||||
|
auto softmax_col_ind = softmax.col_indices().vec<int32>();
|
||||||
|
auto softmax_values = softmax.values().vec<T>();
|
||||||
|
auto grad_softmax_host_batch_ptr = grad_softmax.batch_pointers().vec<int32>();
|
||||||
|
auto grad_softmax_row_ptr = grad_softmax.row_pointers().vec<int32>();
|
||||||
|
auto grad_softmax_col_ind = grad_softmax.col_indices().vec<int32>();
|
||||||
|
auto grad_softmax_values = grad_softmax.values().vec<T>();
|
||||||
|
|
||||||
|
const int ndims = host_dense_shape.size();
|
||||||
|
DCHECK(ndims == 2 || ndims == 3);
|
||||||
|
const int rows = host_dense_shape(0);
|
||||||
|
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
||||||
|
if (ndims == 2) {
|
||||||
|
DCHECK_EQ(rows + 1, softmax_row_ptr.size());
|
||||||
|
DCHECK_EQ(rows + 1, grad_softmax_row_ptr.size());
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(rows /*size*/, d);
|
||||||
|
CSRSparseMatrixSoftmaxGradKernel2D<T>
|
||||||
|
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||||
|
rows /*size*/, softmax_row_ptr.data(), softmax_col_ind.data(),
|
||||||
|
softmax_values.data(), grad_softmax_row_ptr.data(),
|
||||||
|
grad_softmax_col_ind.data(), grad_softmax_values.data(),
|
||||||
|
gradient_values.data());
|
||||||
|
} else {
|
||||||
|
const int batch_size = host_dense_shape(0);
|
||||||
|
const int rows = host_dense_shape(1);
|
||||||
|
DCHECK_EQ(batch_size, softmax_host_batch_ptr.size() - 1);
|
||||||
|
DCHECK_EQ(batch_size, grad_softmax_host_batch_ptr.size() - 1);
|
||||||
|
DCHECK_EQ((rows + 1) * batch_size, softmax_row_ptr.size());
|
||||||
|
DCHECK_EQ((rows + 1) * batch_size, grad_softmax_row_ptr.size());
|
||||||
|
const int size = rows * batch_size;
|
||||||
|
// The length of softmax_and_grad_batch_ptr_copy is 2 * (batch_size + 1)
|
||||||
|
// The first (batch_size + 1) entries contain softmax_batch_ptr and
|
||||||
|
// the second (batch_size + 1) entries contain grad_softmax_batch_ptr.
|
||||||
|
GpuDeviceArrayOnHost<int> softmax_and_grad_batch_ptr_copy(
|
||||||
|
ctx, 2 * softmax_host_batch_ptr.size());
|
||||||
|
TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Init());
|
||||||
|
for (int i = 0; i < softmax_host_batch_ptr.size(); ++i) {
|
||||||
|
softmax_and_grad_batch_ptr_copy.Set(i, softmax_host_batch_ptr(i));
|
||||||
|
softmax_and_grad_batch_ptr_copy.Set(batch_size + 1 + i,
|
||||||
|
grad_softmax_host_batch_ptr(i));
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(softmax_and_grad_batch_ptr_copy.Finalize());
|
||||||
|
|
||||||
|
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
|
||||||
|
// shared memory stores two copies of batch pointers: one for the
|
||||||
|
// softmax CSR matrix, one for the grad_softmax CSR matrix.
|
||||||
|
const size_t shared_memory_size = 2 * sizeof(int) * (batch_size + 1);
|
||||||
|
CSRSparseMatrixSoftmaxGradKernel3D<T>
|
||||||
|
<<<config.block_count, config.thread_per_block, shared_memory_size,
|
||||||
|
d.stream()>>>(size, rows, softmax_and_grad_batch_ptr_copy.data(),
|
||||||
|
softmax_row_ptr.data(), softmax_col_ind.data(),
|
||||||
|
softmax_values.data(), grad_softmax_row_ptr.data(),
|
||||||
|
grad_softmax_col_ind.data(),
|
||||||
|
grad_softmax_values.data(), gradient_values.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DEFINE_SOFTMAX_GRAD_GPU(T) \
|
||||||
|
template <> \
|
||||||
|
Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()( \
|
||||||
|
OpKernelContext* ctx, const CSRSparseMatrix& softmax, \
|
||||||
|
const CSRSparseMatrix& grad_softmax, \
|
||||||
|
typename TTypes<T>::Vec gradient_values) { \
|
||||||
|
return CSRSparseMatrixSoftmaxGradGPUImpl<T>(ctx, softmax, grad_softmax, \
|
||||||
|
gradient_values); \
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFINE_SOFTMAX_GRAD_GPU(float);
|
||||||
|
DEFINE_SOFTMAX_GRAD_GPU(double);
|
||||||
|
|
||||||
|
#undef DEFINE_SOFTMAX_GRAD_GPU
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
tensorflow/core/kernels/sparse/kernels_test.cc (new file, 82 lines)
@ -0,0 +1,82 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
TEST(SparseTensorToCSRSparseMatrix, SingleBatchConversion) {
|
||||||
|
const auto indices =
|
||||||
|
test::AsTensor<int64>({0, 0, 2, 3, 2, 4, 3, 0}, TensorShape({4, 2}));
|
||||||
|
Tensor batch_ptr(DT_INT32, {2});
|
||||||
|
Tensor csr_col_ind(DT_INT32, {4});
|
||||||
|
auto csr_row_ptr = test::AsTensor<int32>({0, 0, 0, 0, 0});
|
||||||
|
|
||||||
|
functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr;
|
||||||
|
TF_EXPECT_OK(coo_to_csr(1 /* batch_size */, 4 /* num_rows */,
|
||||||
|
indices.template matrix<int64>(),
|
||||||
|
batch_ptr.vec<int32>(), csr_row_ptr.vec<int32>(),
|
||||||
|
csr_col_ind.vec<int32>()));
|
||||||
|
|
||||||
|
test::ExpectTensorEqual<int32>(batch_ptr, test::AsTensor<int32>({0, 4}));
|
||||||
|
test::ExpectTensorEqual<int32>(csr_row_ptr,
|
||||||
|
test::AsTensor<int32>({0, 1, 1, 3, 4}));
|
||||||
|
test::ExpectTensorEqual<int32>(csr_col_ind,
|
||||||
|
test::AsTensor<int32>({0, 3, 4, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SparseTensorToCSRSparseMatrix, BatchConversion) {
|
||||||
|
// Batch of 3 matrices, each having dimension [3, 4] with 3 non-zero elements.
|
||||||
|
const auto indices = test::AsTensor<int64>({0, 0, 0, //
|
||||||
|
0, 2, 3, //
|
||||||
|
2, 0, 1},
|
||||||
|
TensorShape({3, 3}));
|
||||||
|
Tensor batch_ptr(DT_INT32, {4});
|
||||||
|
Tensor csr_col_ind(DT_INT32, {3});
|
||||||
|
// row pointers have size = batch_size * (num_rows + 1) = 3 * 4 = 12
|
||||||
|
Tensor csr_row_ptr(DT_INT32, {12});
|
||||||
|
test::FillFn<int32>(&csr_row_ptr, [](int unused) { return 0; });
|
||||||
|
|
||||||
|
functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr;
|
||||||
|
TF_EXPECT_OK(coo_to_csr(3 /* batch_size */, 3 /* num_rows */,
|
||||||
|
indices.template matrix<int64>(),
|
||||||
|
batch_ptr.vec<int32>(), csr_row_ptr.vec<int32>(),
|
||||||
|
csr_col_ind.vec<int32>()));
|
||||||
|
|
||||||
|
test::ExpectTensorEqual<int32>(batch_ptr,
|
||||||
|
test::AsTensor<int32>({0, 2, 2, 3}));
|
||||||
|
test::ExpectTensorEqual<int32>(csr_row_ptr,
|
||||||
|
test::AsTensor<int32>({0, 1, 1, 2, //
|
||||||
|
0, 0, 0, 0, //
|
||||||
|
0, 1, 1, 1}));
|
||||||
|
test::ExpectTensorEqual<int32>(csr_col_ind, test::AsTensor<int32>({0, 3, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
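The conversion these tests exercise is a counting sort over rows followed by a cumulative sum. Below is a minimal standalone sketch (plain C++, not the TensorFlow functor) that reproduces the single-batch case, assuming the COO entries are already sorted by row as in the tests; names like CooToCsr are hypothetical.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct Csr {
  std::vector<int32_t> row_ptr;  // size num_rows + 1
  std::vector<int32_t> col_ind;  // size nnz
};

Csr CooToCsr(const std::vector<std::pair<int64_t, int64_t>>& coo,
             int64_t num_rows) {
  Csr csr;
  csr.row_ptr.assign(num_rows + 1, 0);
  csr.col_ind.reserve(coo.size());
  // Count nonzeros per row, then take a cumulative sum; this yields the
  // {0, 1, 1, 3, 4} row pointers expected by the first test above.
  for (const auto& rc : coo) ++csr.row_ptr[rc.first + 1];
  for (int64_t r = 0; r < num_rows; ++r) csr.row_ptr[r + 1] += csr.row_ptr[r];
  // Row-sorted input means the column indices can be copied through directly.
  for (const auto& rc : coo) csr.col_ind.push_back(static_cast<int32_t>(rc.second));
  return csr;
}

int main() {
  // Indices from the SingleBatchConversion test: (0,0), (2,3), (2,4), (3,0).
  Csr csr = CooToCsr({{0, 0}, {2, 3}, {2, 4}, {3, 0}}, /*num_rows=*/4);
  for (int32_t p : csr.row_ptr) std::printf("%d ", p);  // prints: 0 1 1 3 4
  std::printf("\n");
  return 0;
}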
tensorflow/core/kernels/sparse/mat_mul_op.cc  Normal file  (436 lines)
@@ -0,0 +1,436 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#include "tensorflow/core/kernels/sparse/transpose_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class CSRMatMulOp : public OpKernel {
 public:
  explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) {
    OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
    bool adjoint_a;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
    OP_REQUIRES(c, !(adjoint_a && transpose_a_),
                errors::InvalidArgument(
                    "Only one of adjoint_a and transpose_a may be true."));
    bool adjoint_b;
    OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
    OP_REQUIRES(c, !(adjoint_b && transpose_b_),
                errors::InvalidArgument(
                    "Only one of adjoint_b and transpose_b may be true."));
    OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_));
    OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_));
    conjugate_a_ = adjoint_a;
    conjugate_b_ = adjoint_b;
    transpose_a_ = transpose_a_ || adjoint_a;
    transpose_b_ = transpose_b_ || adjoint_b;
  }

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    OP_REQUIRES(ctx, a_matrix->dtype() == b_t.dtype(),
                errors::InvalidArgument(
                    "Input types don't match.  a.dtype == ",
                    DataTypeString(a_matrix->dtype()),
                    " vs. b.dtype == ", DataTypeString(b_t.dtype())));

    const int a_rank = a_matrix->dims();
    const int b_rank = b_t.dims();
    const int64 batch_size = (b_rank == 2) ? 1 : b_t.dim_size(0);

    // TODO(ebrevdo): Add support for broadcasting matmul.
    OP_REQUIRES(ctx, a_rank == b_rank,
                errors::InvalidArgument("Ranks of a and b must match, saw: ",
                                        a_rank, " vs. ", b_rank, "."));
    OP_REQUIRES(ctx, a_matrix->batch_size() == batch_size,
                errors::InvalidArgument(
                    "Batch sizes of a and b must match, saw: ",
                    a_matrix->batch_size(), " vs. ", batch_size, "."));

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    TensorShape a_dense_tensor_shape;
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    OP_REQUIRES_OK(
        ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape));

    const int row_dim = (a_rank == 2) ? 0 : 1;
    const int64 a_inner_dim =
        a_dense_tensor_shape.dim_size(transpose_a_ ? row_dim : row_dim + 1);
    const int64 b_inner_dim =
        b_t.shape().dim_size(transpose_b_ ? row_dim + 1 : row_dim);
    const int64 b_outer_dim =
        b_t.shape().dim_size(transpose_b_ ? row_dim : row_dim + 1);
    const int64 b_slice_size = b_inner_dim * b_outer_dim;

    OP_REQUIRES(
        ctx, a_inner_dim == b_inner_dim,
        errors::InvalidArgument(
            "Inner product dimensions of A and B do not agree.  Shapes are: ",
            a_dense_tensor_shape.DebugString(), " vs. ",
            b_t.shape().DebugString()));

    TensorShape c_shape;
    if (a_rank == 3) c_shape.AddDim(batch_size);
    if (transpose_output_) {
      c_shape.AddDim(b_t.dim_size(transpose_b_ ? row_dim : row_dim + 1));
      c_shape.AddDim(
          a_dense_tensor_shape.dim_size(transpose_a_ ? row_dim + 1 : row_dim));
    } else {
      c_shape.AddDim(
          a_dense_tensor_shape.dim_size(transpose_a_ ? row_dim + 1 : row_dim));
      c_shape.AddDim(b_t.dim_size(transpose_b_ ? row_dim : row_dim + 1));
    }

    const int64 c_matrix_lhs = c_shape.dim_size(row_dim);
    const int64 c_matrix_rhs = c_shape.dim_size(row_dim + 1);
    const int64 c_slice_size = c_matrix_lhs * c_matrix_rhs;
    Tensor* c_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));

    const Device& d = ctx->eigen_device<Device>();

    if (b_outer_dim == 1) {
      // Call matrix-vector multiply if b is a vector.
      TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
                                                 2);
      Tensor b_conj_t;
      const T* b_base_ptr = b_t.template flat<T>().data();
      bool conjugate_a = conjugate_a_;
      bool conjugate_output = conjugate_output_;
      if (conjugate_b_) {
        if (conjugate_a) {
          // In this case we can use the identity
          //   conj(a) * conj(b) = conj(a * b)
          // instead of creating a conjugated copy of b.
          conjugate_a = false;
          conjugate_output = !conjugate_output;
        } else {
          OP_REQUIRES_OK(
              ctx, ctx->forward_input_or_allocate_temp(
                       {1}, DataTypeToEnum<T>::value, b_t.shape(), &b_conj_t));
          functor::maybe_conj<Device, T>::run(d, b_t, &b_conj_t);
          b_base_ptr = b_conj_t.template flat<T>().data();
        }
      }

      functor::CSRSparseMatrixMatVec<Device, T> csr_spmv(transpose_a_,
                                                         conjugate_a);
      for (int i = 0; i < batch_size; ++i) {
        auto a_row_ptr = a_matrix->row_pointers_vec(i);
        auto a_col_ind = a_matrix->col_indices_vec(i);
        auto a_values = a_matrix->values_vec<T>(i);
        ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                    a_dense_shape_comp};
        const T* b_i = b_base_ptr + i * b_slice_size;
        T* c_i = &c_t->template flat<T>()(i * c_slice_size);
        Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i);
        OP_REQUIRES_OK(ctx, s);
      }
      if (conjugate_output) {
        functor::maybe_conj_inplace<Device, T>::run(d, c_t);
      }
      return;
    }

    functor::CSRSparseMatrixMatMul<Device, T> csr_spmmadd(transpose_output_);

    Tensor c_mat_col_major_t;
    if (!transpose_output_) {
      // If transpose_output is false, we'll need to transpose the (col
      // major) output of the csrgemm call to get proper (row-major)
      // output.  Which means we need to keep a temporary buffer to
      // store the intermediate gemm output.
      TensorShape c_mat_col_major_shape;
      if (a_rank == 2) {
        c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs});
      } else {
        c_mat_col_major_shape =
            TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs});
      }
      OP_REQUIRES_OK(
          ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                  c_mat_col_major_shape, &c_mat_col_major_t));
    }

    // If transpose_output is true, return the direct (column-major i.e.,
    // transposed) output of the csrgemm call.  Otherwise we'll need
    // to transpose it to row major format.
    auto c_mat_col_major =
        (transpose_output_) ? c_t->flat<T>() : c_mat_col_major_t.flat<T>();

    // Possibly transpose a.
    const CSRSparseMatrix* a_input_matrix;
    // If we need to transpose a, we will store the result temporarily
    // in the object below.
    CSRSparseMatrix a_matrix_transposed;
    if (!transpose_a_) {
      a_input_matrix = a_matrix;
    } else {
      functor::CSRSparseMatrixTranspose<Device, T> transpose;
      OP_REQUIRES_OK(
          ctx, transpose(ctx, conjugate_a_, *a_matrix, &a_matrix_transposed));
      a_input_matrix = &a_matrix_transposed;
    }

    auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();

    // Possibly transpose b.
    Tensor b_t_input;
    if (!transpose_b_) {
      b_t_input = b_t;
    } else {
      TensorShape b_t_transposed_shape;
      if (a_rank == 3) {
        b_t_transposed_shape.AddDim(batch_size);
      }
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim + 1));
      b_t_transposed_shape.AddDim(b_t.dim_size(row_dim));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                             b_t_transposed_shape, &b_t_input));
      const Device& d = ctx->eigen_device<Device>();
      if (conjugate_b_) {
        OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/,
                                                       &b_t_input /*output*/));
      } else {
        OP_REQUIRES_OK(
            ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/));
      }
    }

    // Dense shape of a batch component of A.
    TTypes<int64>::ConstVec a_input_dense_shape_comp(
        a_input_dense_shape.data() + row_dim, 2);

    auto b = b_t_input.flat<T>();

    for (int i = 0; i < batch_size; ++i) {
      auto a_row_ptr = a_input_matrix->row_pointers_vec(i);
      auto a_col_ind = a_input_matrix->col_indices_vec(i);
      auto a_values = a_input_matrix->values_vec<T>(i);
      typename TTypes<T>::UnalignedConstMatrix b_i(b.data() + i * b_slice_size,
                                                   {b_inner_dim, b_outer_dim});
      typename TTypes<T>::UnalignedMatrix c_mat_col_major_i(
          c_mat_col_major.data() + i * c_slice_size,
          {c_matrix_lhs, c_matrix_rhs});
      ConstCSRComponent<T> a_comp{a_row_ptr, a_col_ind, a_values,
                                  a_input_dense_shape_comp};
      Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i);
      OP_REQUIRES_OK(ctx, s);
    }

    if (!transpose_output_) {
      // We need to return values in row major format, so transpose
      // the column-major values in c_mat_col_major_t to row-major output c_t.
      OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t,
                                            /*output=*/c_t));
    }
    if (conjugate_output_) {
      functor::maybe_conj_inplace<Device, T>::run(d, c_t);
    }
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool conjugate_a_;
  bool conjugate_b_;
  bool transpose_output_;
  bool conjugate_output_;
};

#define REGISTER(DEV, T) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseMatrixMatMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \
      CSRMatMulOp<DEV##Device, T>);

#if GOOGLE_CUDA

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER

#if GOOGLE_CUDA

namespace functor {

template <typename T>
class CSRSparseMatrixMatMul<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMatMul(const bool transpose_output)
      : transpose_output_(transpose_output) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 typename TTypes<T>::UnalignedConstMatrix b,
                 typename TTypes<T>::UnalignedMatrix c) {
    CudaSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmm to calculate:
      //   C = alpha * op(A) * op(B) + beta * C
      // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense.
      // Note that Csrmm assumes B and C are in column-major form; so we
      // use transB == true, and manually transpose the output in place
      // using blas<t>geam.
      // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta.
      const T alpha = 1;
      const T beta = 0;

      // transA must be non-transpose if transB is transpose (cusparse
      // limitation).
      const cusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE;

      // transB: b is row-major, and cusparse requires col-major b (or
      // equivalently transB == transpose).  this version is actually more
      // efficient.
      const cusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE;

      cusparseMatDescr_t descrA;
      TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_CUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_CUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

      // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n)
      const int k = b.dimension(0);
      DCHECK_EQ(k, a.dense_shape_host(1));

      // If transpose_output_ is true, then the c matrix we receive
      // here is the direct row major output (into which we will store
      // csrgemm's col major output).  Otherwise it's a
      // temporary tensor that will store the column major output that
      // will eventually be transposed.
      const int m = c.dimension(transpose_output_ ? 1 : 0);
      const int n = c.dimension(transpose_output_ ? 0 : 1);
      DCHECK_EQ(m, a.dense_shape_host(0));
      DCHECK_EQ(n, b.dimension(1));
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());

      // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k)
      // if op(A) = A and at least max (1, m) otherwise. If op(B) != B, it must
      // be at least max(1, n).
      const int ldb = n;
      // ldc: leading dimension of C. It must be at least max(1, m) if
      // op(A) = A and at least max(1, k) otherwise.
      const int ldc = m;

      TF_RETURN_IF_ERROR(
          cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA,
                            a.values.data(), a.row_ptr.data(), a.col_ind.data(),
                            b.data(), ldb, &beta, c.data(), ldc));
    }

    return Status::OK();
  }

 private:
  bool transpose_output_;
};

template <typename T>
class CSRSparseMatrixMatVec<GPUDevice, T> {
 public:
  CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a)
      : transA_(TransposeAndConjugateToCuSparseOp(transpose_a, conjugate_a,
                                                  &status_)) {}

  Status Compute(OpKernelContext* ctx, const ConstCSRComponent<T>& a,
                 const T* x, T* y) {
    TF_RETURN_IF_ERROR(status_);
    CudaSparse cuda_sparse(ctx);
    TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
    {
      // Use Csrmv to calculate:
      //   y = alpha * op(A) * x + beta * y
      // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are
      // dense vectors.

      // Create alpha and beta scalars; alpha = 1.0, beta = 0.0
      // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta.
      const T alpha = 1;
      const T beta = 0;

      cusparseMatDescr_t descrA;
      TF_RETURN_IF_CUSPARSE_ERROR(cusparseCreateMatDescr(&descrA));
      TF_RETURN_IF_CUSPARSE_ERROR(
          cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL));
      TF_RETURN_IF_CUSPARSE_ERROR(
          cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO));

      const int m = a.dense_shape_host(0);
      const int n = a.dense_shape_host(1);
      const int nnz = a.values.size();
      DCHECK_EQ(nnz, a.col_ind.size());
      TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA,
                                           a.values.data(), a.row_ptr.data(),
                                           a.col_ind.data(), x, &beta, y));
    }

    return Status::OK();
  }

 private:
  Status status_;
  const cusparseOperation_t transA_;
};

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
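The column-major intermediate used above relies on a standard layout fact: a row-major buffer reinterpreted as column-major with swapped dimensions is exactly the transpose. A minimal standalone sketch (assumes a standard Eigen installation and include paths; not part of this commit) illustrating why the kernel can hand cusparse a "transposed" view of B and then transpose C at the end:

#include <iostream>
#include <Eigen/Dense>

int main() {
  // A 2x3 row-major matrix and its underlying buffer [1, 2, 3, 4, 5, 6].
  Eigen::Matrix<float, 2, 3, Eigen::RowMajor> c_row;
  c_row << 1, 2, 3,
           4, 5, 6;
  // The same buffer viewed as a 3x2 column-major matrix equals the transpose.
  Eigen::Map<Eigen::Matrix<float, 3, 2, Eigen::ColMajor>> c_col(c_row.data());
  std::cout << std::boolalpha
            << c_col.isApprox(c_row.transpose()) << "\n";  // prints true
  return 0;
}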
tensorflow/core/kernels/sparse/mul_op.cc  Normal file  (171 lines)
@@ -0,0 +1,171 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class CSRMulOp : public OpKernel {
 public:
  explicit CSRMulOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* ctx) final {
    const CSRSparseMatrix* a_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
    const Tensor& b_t = ctx->input(1);

    OP_REQUIRES(ctx, a_matrix->dtype() == b_t.dtype(),
                errors::InvalidArgument(
                    "Input types don't match.  a.dtype == ",
                    DataTypeString(a_matrix->dtype()),
                    " vs. b.dtype == ", DataTypeString(b_t.dtype())));

    const int b_rank = b_t.dims();

    const Tensor& a_dense_shape_t = a_matrix->dense_shape();
    auto a_dense_shape = a_dense_shape_t.vec<int64>();
    const int batch_size = a_dense_shape(0);
    if (b_rank == 3) {
      OP_REQUIRES(
          ctx,
          ((a_matrix->dims() == 3) && (b_t.dim_size(0) == batch_size) &&
           (b_t.NumElements() == batch_size)),
          errors::InvalidArgument(
              "If b is a rank-3 tensor, then a must be rank 3 and the shape "
              "of b must be [batch_size, 1, 1].  But the shape of b is: ",
              b_t.shape().DebugString(),
              " and the shape of a is: ", a_dense_shape_t.DebugString()));
    } else {
      OP_REQUIRES(ctx, b_rank == 0,
                  errors::Unimplemented(
                      "Multiplying by a 2D+ dense tensor is not currently "
                      "supported, but shape of b is: ",
                      b_t.shape().DebugString()));
    }

    Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    CSRSparseMatrix c_matrix;
    if (b_rank == 0) {
      auto b = b_t.scalar<T>();
      // TODO(ebrevdo): call other functor if b is nonscalar.
      functor::CSRSparseMatrixMulScalar<Device, T> csrmul_scalar;
      OP_REQUIRES_OK(ctx, csrmul_scalar.Compute(ctx, *a_matrix, b, &c_matrix));
    } else {
      // b is rank-3 with shape [batch_size, 1, 1] and a_matrix is rank-3;
      // treat b as a vector of per-batch scalars.
      auto b = b_t.flat<T>();
      functor::CSRSparseMatrixBatchMulVec<Device, T> csrmul_batch_vec;
      OP_REQUIRES_OK(ctx,
                     csrmul_batch_vec.Compute(ctx, *a_matrix, b, &c_matrix));
    }
    c_t.scalar<Variant>()() = std::move(c_matrix);
    ctx->set_output(0, c_t);
  }
};

#define REGISTER(DEV, T) \
  REGISTER_KERNEL_BUILDER( \
      Name("SparseMatrixMul").Device(DEVICE_##DEV).TypeConstraint<T>("T"), \
      CSRMulOp<DEV##Device, T>);

#if GOOGLE_CUDA

#define REGISTER_GPU(T) REGISTER(GPU, T)

REGISTER_GPU(float)
REGISTER_GPU(double)
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)

#undef REGISTER_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER

#if GOOGLE_CUDA

namespace functor {

template <typename T>
class CSRSparseMatrixMulScalar<GPUDevice, T> {
 public:
  explicit CSRSparseMatrixMulScalar() {}

  Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a,
                 typename TTypes<T>::ConstScalar b, CSRSparseMatrix* c) {
    const int total_nnz = a.total_nnz();
    Tensor c_values_t;
    TF_RETURN_IF_ERROR(ctx->allocate_temp(
        DataTypeToEnum<T>::value, TensorShape({total_nnz}), &c_values_t));
    TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
        DataTypeToEnum<T>::value, a.dense_shape(), a.batch_pointers(),
        a.row_pointers(), a.col_indices(), c_values_t, c));

    auto a_values = a.values().flat<T>();
    auto c_values = c_values_t.flat<T>();

    const GPUDevice& d = ctx->eigen_device<GPUDevice>();
    bool error;
    bool* const error_ptr = functor::mul<T>::has_errors ? &error : nullptr;

    // tensor * scalar
    functor::BinaryFunctor<GPUDevice, functor::mul<T>, 1>().Right(
        d, c_values, a_values, b, error_ptr);

    return Status::OK();
  }
};

#define DECLARE_GPU_SPEC(T) \
  template <> \
  Status CSRSparseMatrixBatchMulVec<GPUDevice, T>::Compute( \
      OpKernelContext* ctx, const CSRSparseMatrix& a, \
      typename TTypes<T>::ConstFlat b, CSRSparseMatrix* c); \
  extern template struct CSRSparseMatrixBatchMulVec<GPUDevice, T>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
DECLARE_GPU_SPEC(std::complex<float>);
DECLARE_GPU_SPEC(std::complex<double>);

#undef DECLARE_GPU_SPEC

}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
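Scaling a CSR matrix by a scalar only touches the values array; the row pointers, column indices, and batch pointers are reused unchanged, which is why the kernel above forwards a's structure tensors directly into c. A tiny standalone sketch (plain C++, not the TensorFlow functor):

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> values = {1.f, 2.f, 3.f};  // nonzero values of A
  const float b = 2.5f;                          // scalar multiplicand
  for (float& v : values) v *= b;                // c_values = b * a_values
  for (float v : values) std::printf("%g ", v);  // prints: 2.5 5 7.5
  std::printf("\n");
  return 0;
}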
tensorflow/core/kernels/sparse/nnz_op.cc  Normal file  (78 lines)
@@ -0,0 +1,78 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device>
class CSRNNZOp : public OpKernel {
 public:
  explicit CSRNNZOp(OpKernelConstruction* c) : OpKernel(c) {}

  void Compute(OpKernelContext* c) final {
    const CSRSparseMatrix* csr_sparse_matrix;
    OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));
    Tensor* nnz_t;
    TensorShape nnz_shape;
    if (csr_sparse_matrix->dims() == 3) {
      nnz_shape.AddDim(csr_sparse_matrix->batch_size());
    }
    OP_REQUIRES_OK(c, c->allocate_output(0, nnz_shape, &nnz_t));
    auto nnz = nnz_t->flat<int32>();
    for (int i = 0; i < csr_sparse_matrix->batch_size(); ++i) {
      nnz(i) = csr_sparse_matrix->nnz(i);
    }
  }
};

#define REGISTER(DEV) \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixNNZ") \
                              .Device(DEVICE_##DEV) \
                              .HostMemory("nnz"), \
                          CSRNNZOp<DEV##Device>);

REGISTER(CPU)

#if GOOGLE_CUDA

REGISTER(GPU)

#endif  // GOOGLE_CUDA

#undef REGISTER

}  // namespace tensorflow
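For a batched CSR matrix, the per-batch nonzero count is the difference of adjacent batch pointers, which is the quantity nnz(i) reports above. A standalone sketch of that bookkeeping (plain C++; the batch_pointers layout here is an assumption about the storage convention, the real accessor lives in CSRSparseMatrix):

#include <cstdio>
#include <vector>

int main() {
  // Three batches with 4, 0, and 3 stored nonzeros respectively.
  std::vector<int> batch_pointers = {0, 4, 4, 7};
  for (size_t i = 0; i + 1 < batch_pointers.size(); ++i) {
    std::printf("nnz(batch %zu) = %d\n", i,
                batch_pointers[i + 1] - batch_pointers[i]);
  }
  return 0;
}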
tensorflow/core/kernels/sparse/softmax_op.cc  Normal file  (225 lines)
@@ -0,0 +1,225 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the kernel for the CSRSoftmax op, which performs softmax
// along the innermost (col) dimension of a CSRSparseMatrix object
// stored in a DT_VARIANT.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/slice_op.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class CSRSoftmaxOp : public OpKernel {
 public:
  explicit CSRSoftmaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const CSRSparseMatrix* logits_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &logits_matrix));
    OP_REQUIRES(
        ctx, logits_matrix->dtype() == DataTypeToEnum<T>::value,
        errors::InvalidArgument("dtype of logits is not equal to 'type': ",
                                DataTypeString(logits_matrix->dtype()), " vs. ",
                                DataTypeString(DataTypeToEnum<T>::value)));

    // Allocate output shapes
    const int total_nnz = logits_matrix->total_nnz();
    Tensor output_values_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
                                TensorShape({total_nnz}), &output_values_t));

    CSRSparseMatrix output_matrix;

    Tensor dense_shape_t = logits_matrix->dense_shape();

    OP_REQUIRES_OK(
        ctx,
        CSRSparseMatrix::CreateCSRSparseMatrix(
            DataTypeToEnum<T>::value, dense_shape_t,
            logits_matrix->batch_pointers(), logits_matrix->row_pointers(),
            logits_matrix->col_indices(), output_values_t, &output_matrix));

    if (total_nnz > 0) {
      functor::CSRSparseMatrixSoftmax<Device, T> softmax;
      OP_REQUIRES_OK(
          ctx, softmax(ctx, *logits_matrix, output_matrix.values().vec<T>()));
    }

    Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    output_t.scalar<Variant>()() = std::move(output_matrix);
    ctx->set_output(0, output_t);
  }
};

#ifdef GOOGLE_CUDA
#define REGISTER(DEV, T) \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \
                              .Device(DEVICE_##DEV) \
                              .TypeConstraint<T>("type"), \
                          CSRSoftmaxOp<DEV##Device, T>);

REGISTER(GPU, float)
REGISTER(GPU, double)

#undef REGISTER

namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()( \
      OpKernelContext* ctx, const CSRSparseMatrix& logits, \
      typename TTypes<T>::Vec softmax_values); \
  extern template struct CSRSparseMatrixSoftmax<GPUDevice, T>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);

#undef DECLARE_GPU_SPEC
}  // namespace functor

#endif  // GOOGLE_CUDA

template <typename Device, typename T>
class CSRSoftmaxGradOp : public OpKernel {
 public:
  explicit CSRSoftmaxGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const CSRSparseMatrix* softmax_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &softmax_matrix));
    OP_REQUIRES(ctx, softmax_matrix->dtype() == DataTypeToEnum<T>::value,
                errors::InvalidArgument(
                    "dtype of softmax is not equal to 'type': ",
                    DataTypeString(softmax_matrix->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));

    const CSRSparseMatrix* grad_softmax_matrix;
    OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &grad_softmax_matrix));
    OP_REQUIRES(ctx, grad_softmax_matrix->dtype() == DataTypeToEnum<T>::value,
                errors::InvalidArgument(
                    "dtype of grad_softmax is not equal to 'type': ",
                    DataTypeString(grad_softmax_matrix->dtype()), " vs. ",
                    DataTypeString(DataTypeToEnum<T>::value)));

    OP_REQUIRES(
        ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(),
        errors::InvalidArgument(
            "Ranks of softmax and grad_softmax matrices differ: ",
            softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims()));

    Tensor dense_shape_t = softmax_matrix->dense_shape();
    auto host_dense_shape =
        static_cast<const Tensor>(dense_shape_t).vec<int64>();

    auto host_grad_dense_shape =
        grad_softmax_matrix->dense_shape().vec<int64>();

    for (int i = 0; i < host_dense_shape.size(); ++i) {
      OP_REQUIRES(ctx, host_dense_shape(i) == host_grad_dense_shape(i),
                  errors::InvalidArgument(
                      "Shapes of softmax and grad_softmax matrices differ: ",
                      dense_shape_t.SummarizeValue(3), " vs. ",
                      grad_softmax_matrix->dense_shape().SummarizeValue(3)));
    }

    // Allocate output shapes.  Note that since the Softmax Gradient
    // tensor is the elementwise product of some function with the
    // softmax value, it will keep the sparsity structure of the softmax.
    const int total_nnz = softmax_matrix->total_nnz();
    PersistentTensor gradient_values_pt;
    Tensor* gradient_values_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
                            DataTypeToEnum<T>::value, TensorShape({total_nnz}),
                            &gradient_values_pt, &gradient_values_t));

    CSRSparseMatrix gradient_matrix;

    OP_REQUIRES_OK(
        ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
                 DataTypeToEnum<T>::value, dense_shape_t,
                 softmax_matrix->batch_pointers(),
                 softmax_matrix->row_pointers(), softmax_matrix->col_indices(),
                 *gradient_values_t, &gradient_matrix));

    if (total_nnz > 0) {
      functor::CSRSparseMatrixSoftmaxGrad<Device, T> softmax_grad;
      OP_REQUIRES_OK(ctx,
                     softmax_grad(ctx, *softmax_matrix, *grad_softmax_matrix,
                                  gradient_matrix.values().vec<T>()));
    }

    Tensor gradient_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
    gradient_t.scalar<Variant>()() = std::move(gradient_matrix);
    ctx->set_output(0, gradient_t);
  }
};

#ifdef GOOGLE_CUDA
#define REGISTER(DEV, T) \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
                              .Device(DEVICE_##DEV) \
                              .TypeConstraint<T>("type"), \
                          CSRSoftmaxGradOp<DEV##Device, T>);

REGISTER(GPU, float)
REGISTER(GPU, double)

#undef REGISTER

namespace functor {
#define DECLARE_GPU_SPEC(T) \
  template <> \
  Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()( \
      OpKernelContext* ctx, const CSRSparseMatrix& softmax, \
      const CSRSparseMatrix& grad_softmax, \
      typename TTypes<T>::Vec gradient_values); \
  extern template struct CSRSparseMatrixSoftmaxGrad<GPUDevice, T>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);

#undef DECLARE_GPU_SPEC
}  // namespace functor

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
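For reference, the forward operation the softmax functor computes, written out as a sketch of the math (not taken from the kernels): for each row i of the CSR matrix, softmax is taken over the stored nonzeros nz(i) only, with the usual max-subtraction for numerical stability,

y_{ij} \;=\; \frac{\exp\!\bigl(x_{ij} - m_i\bigr)}{\sum_{k \in \mathrm{nz}(i)} \exp\!\bigl(x_{ik} - m_i\bigr)}, \qquad m_i = \max_{k \in \mathrm{nz}(i)} x_{ik}, \quad j \in \mathrm{nz}(i).

Implicit zeros stay zero, so the output reuses the logits' batch pointers, row pointers, and column indices and only new values are allocated, as the op above does.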
288
tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
Normal file
288
tensorflow/core/kernels/sparse/sparse_cholesky_op.cc
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <numeric>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
#include "third_party/eigen3/Eigen/SparseCholesky"
|
||||||
|
#include "third_party/eigen3/Eigen/SparseCore"
|
||||||
|
#include "third_party/eigen3/Eigen/OrderingMethods"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Op to compute the sparse Cholesky factorization of a sparse matrix.
|
||||||
|
//
|
||||||
|
// Implements a CPU kernel which returns the lower triangular sparse Cholesky
|
||||||
|
// factor of a CSRSparseMatrix, using the fill-in reducing permutation.
|
||||||
|
//
|
||||||
|
// The CSRSparseMatrix may represent a single sparse matrix (rank 2) or a batch
|
||||||
|
// of sparse matrices (rank 3). Each component must represent a symmetric
|
||||||
|
// positive definite (SPD) matrix. In particular, this means the component
|
||||||
|
// matrices must be square. We don't actually check if the input is symmetric,
|
||||||
|
// only the lower triangular part of each component is read.
|
||||||
|
//
|
||||||
|
// The associated permutation must be a Tensor of rank (R - 1), where the
|
||||||
|
// CSRSparseMatrix has rank R. Additionally, the batch dimension of the
|
||||||
|
// CSRSparseMatrix and the permutation must be the same. Each batch of
|
||||||
|
// the permutation should the contain each of the integers [0,..,N - 1] exactly
|
||||||
|
// once, where N is the number of rows of each CSR SparseMatrix component.
|
||||||
|
// TODO(anudhyan): Add checks to throw an InvalidArgument error if the
|
||||||
|
// permutation is not valid.
|
||||||
|
//
|
||||||
|
// Returns a CSRSparseMatrix representing the lower triangular (batched)
|
||||||
|
// Cholesky factors. It has the same shape as the input CSRSparseMatrix. For
|
||||||
|
// each component sparse matrix A, the corresponding output sparse matrix L
|
||||||
|
// satisfies the identity:
|
||||||
|
// A = L * Lt
|
||||||
|
// where Lt denotes the adjoint of L.
|
||||||
|
//
|
||||||
|
// TODO(b/126472741): Due to the multiple batches of a 3D CSRSparseMatrix being
|
||||||
|
// laid out in contiguous memory, this implementation allocates memory to store
|
||||||
|
// a temporary copy of the Cholesky factor. Consequently, it uses roughly twice
|
||||||
|
// the amount of memory that it needs to. This may cause a memory blowup for
|
||||||
|
// sparse matrices with a high number of non-zero elements.
|
||||||
|
template <typename T>
|
||||||
|
class CSRSparseCholeskyCPUOp : public OpKernel {
|
||||||
|
// Note: We operate in column major (CSC) format in this Op since the
|
||||||
|
// SimplicialLLT returns the factor in column major.
|
||||||
|
using SparseMatrix = Eigen::SparseMatrix<T, Eigen::ColMajor>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit CSRSparseCholeskyCPUOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) final {
|
||||||
|
// Extract inputs and valididate shapes and types.
|
||||||
|
const CSRSparseMatrix* input_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
|
||||||
|
const Tensor& input_permutation_indices = ctx->input(1);
|
||||||
|
|
||||||
|
int64 num_rows;
|
||||||
|
int batch_size;
|
||||||
|
ValidateInputs(ctx, *input_matrix, input_permutation_indices, &batch_size,
|
||||||
|
&num_rows);
|
||||||
|
|
||||||
|
// Allocate batch pointers.
|
||||||
|
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
|
||||||
|
auto batch_ptr_vec = batch_ptr.vec<int32>();
|
||||||
|
batch_ptr_vec(0) = 0;
|
||||||
|
|
||||||
|
// Temporary vector of Eigen SparseMatrices to store the Sparse Cholesky
|
||||||
|
// factors.
|
||||||
|
// Note: we use column-compressed (CSC) SparseMatrix because SimplicialLLT
|
||||||
|
// returns the factors in column major format. Since our input should be
|
||||||
|
// symmetric, column major and row major is identical in storage. We just
|
||||||
|
// have to switch to reading the upper triangular part of the input, which
|
||||||
|
// corresponds to the lower triangular part in row major format.
|
||||||
|
std::vector<SparseMatrix> sparse_cholesky_factors(batch_size);
|
||||||
|
|
||||||
|
// TODO(anudhyan): Tune the cost per unit based on benchmarks.
|
||||||
|
const double nnz_per_row =
|
||||||
|
(input_matrix->total_nnz() / batch_size) / num_rows;
|
||||||
|
const int64 sparse_cholesky_cost_per_batch =
|
||||||
|
nnz_per_row * nnz_per_row * num_rows;
|
||||||
|
// Perform sparse Cholesky factorization of each batch in parallel.
|
||||||
|
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||||
|
std::atomic<int64> invalid_input_index(-1);
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||||
|
sparse_cholesky_cost_per_batch,
|
||||||
|
[&](int64 batch_begin, int64 batch_end) {
|
||||||
|
for (int64 batch_index = batch_begin; batch_index < batch_end;
|
||||||
|
++batch_index) {
|
||||||
|
// Define an Eigen SparseMatrix Map to operate on the
|
||||||
|
// CSRSparseMatrix component without copying the data.
|
||||||
|
Eigen::Map<const SparseMatrix> sparse_matrix(
|
||||||
|
num_rows, num_rows, input_matrix->nnz(batch_index),
|
||||||
|
input_matrix->row_pointers_vec(batch_index).data(),
|
||||||
|
input_matrix->col_indices_vec(batch_index).data(),
|
||||||
|
input_matrix->values_vec<T>(batch_index).data());
|
||||||
|
|
||||||
|
Eigen::SimplicialLLT<SparseMatrix, Eigen::Upper,
|
||||||
|
Eigen::NaturalOrdering<int>>
|
||||||
|
solver;
|
||||||
|
auto permutation_indices_flat =
|
||||||
|
input_permutation_indices.flat<int32>().data();
|
||||||
|
|
||||||
|
// Invert the fill-in reducing ordering and apply it to the input
|
||||||
|
// sparse matrix.
|
||||||
|
Eigen::Map<
|
||||||
|
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic, int>>
|
||||||
|
permutation(permutation_indices_flat + batch_index * num_rows,
|
||||||
|
num_rows);
|
||||||
|
auto permutation_inverse = permutation.inverse();
|
||||||
|
|
||||||
|
SparseMatrix permuted_sparse_matrix;
|
||||||
|
permuted_sparse_matrix.template selfadjointView<Eigen::Upper>() =
|
||||||
|
sparse_matrix.template selfadjointView<Eigen::Upper>()
|
||||||
|
.twistedBy(permutation_inverse);
|
||||||
|
|
||||||
|
// Compute the Cholesky decomposition.
|
||||||
|
solver.compute(permuted_sparse_matrix);
|
||||||
|
if (solver.info() != Eigen::Success) {
|
||||||
|
invalid_input_index = batch_index;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the upper triangular factor, which would end up in the
|
||||||
|
// lower triangular part of the output CSRSparseMatrix when
|
||||||
|
// interpreted in row major format.
|
||||||
|
sparse_cholesky_factors[batch_index] =
|
||||||
|
solver.matrixU().twistedBy(permutation);
|
||||||
|
|
||||||
|
// For now, batch_ptr contains the number of nonzeros in each
|
||||||
|
// batch.
|
||||||
|
batch_ptr_vec(batch_index + 1) =
|
||||||
|
sparse_cholesky_factors[batch_index].nonZeros();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check for invalid input.
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, invalid_input_index == -1,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Sparse Cholesky factorization failed for batch index ",
|
||||||
|
invalid_input_index.load(), ". The input might not be valid."));
|
||||||
|
|
||||||
|
// Compute a cumulative sum to obtain the batch pointers.
|
||||||
|
std::partial_sum(batch_ptr_vec.data(),
|
||||||
|
batch_ptr_vec.data() + batch_size + 1,
|
||||||
|
batch_ptr_vec.data());
|
||||||
|
|
||||||
|
// Allocate output Tensors.
|
||||||
|
const int64 total_nnz = batch_ptr_vec(batch_size);
|
||||||
|
Tensor output_row_ptr(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({(num_rows + 1) * batch_size}));
|
||||||
|
Tensor output_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
|
||||||
|
Tensor output_values(cpu_allocator(), DataTypeToEnum<T>::value,
|
||||||
|
TensorShape({total_nnz}));
|
||||||
|
auto output_row_ptr_ptr = output_row_ptr.flat<int32>().data();
|
||||||
|
auto output_col_ind_ptr = output_col_ind.flat<int32>().data();
|
||||||
|
auto output_values_ptr = output_values.flat<T>().data();
|
||||||
|
|
||||||
|
// Copy the output matrices from each batch into the CSRSparseMatrix
|
||||||
|
// Tensors.
|
||||||
|
// TODO(b/129906419): Factor out the copy from Eigen SparseMatrix to
|
||||||
|
// CSRSparseMatrix into common utils. This is also used in
|
||||||
|
// SparseMatrixSparseMatMul.
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||||
|
(3 * total_nnz) / batch_size /* cost per unit */,
|
||||||
|
[&](int64 batch_begin, int64 batch_end) {
|
||||||
|
for (int64 batch_index = batch_begin; batch_index < batch_end;
|
||||||
|
++batch_index) {
|
||||||
|
const SparseMatrix& cholesky_factor =
|
||||||
|
sparse_cholesky_factors[batch_index];
|
||||||
|
const int64 nnz = cholesky_factor.nonZeros();
|
||||||
|
|
||||||
|
std::copy(cholesky_factor.outerIndexPtr(),
|
||||||
|
cholesky_factor.outerIndexPtr() + num_rows + 1,
|
||||||
|
output_row_ptr_ptr + batch_index * (num_rows + 1));
|
||||||
|
std::copy(cholesky_factor.innerIndexPtr(),
|
||||||
|
cholesky_factor.innerIndexPtr() + nnz,
|
||||||
|
output_col_ind_ptr + batch_ptr_vec(batch_index));
|
||||||
|
std::copy(cholesky_factor.valuePtr(),
|
||||||
|
cholesky_factor.valuePtr() + nnz,
|
||||||
|
output_values_ptr + batch_ptr_vec(batch_index));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create the CSRSparseMatrix instance from its component Tensors and
|
||||||
|
// prepare the Variant output Tensor.
|
||||||
|
CSRSparseMatrix output_csr_matrix;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx,
|
||||||
|
CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
DataTypeToEnum<T>::value, input_matrix->dense_shape(), batch_ptr,
|
||||||
|
output_row_ptr, output_col_ind, output_values, &output_csr_matrix));
|
||||||
|
Tensor* output_csr_matrix_tensor;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
|
||||||
|
cpu_alloc));
|
||||||
|
output_csr_matrix_tensor->scalar<Variant>()() =
|
||||||
|
std::move(output_csr_matrix);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void ValidateInputs(OpKernelContext* ctx,
|
||||||
|
const CSRSparseMatrix& sparse_matrix,
|
||||||
|
const Tensor& permutation_indices, int* batch_size,
|
||||||
|
int64* num_rows) {
|
||||||
|
OP_REQUIRES(ctx, sparse_matrix.dtype() == DataTypeToEnum<T>::value,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Asked for a CSRSparseMatrix of type ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value),
|
||||||
|
" but saw dtype: ", DataTypeString(sparse_matrix.dtype())));
|
||||||
|
|
||||||
|
const Tensor& dense_shape = sparse_matrix.dense_shape();
|
||||||
|
const int rank = dense_shape.dim_size(0);
|
||||||
|
OP_REQUIRES(ctx, rank == 2 || rank == 3,
|
||||||
|
errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
|
||||||
|
"but dense_shape has size ", rank));
|
||||||
|
const int row_dim = (rank == 2) ? 0 : 1;
|
||||||
|
auto dense_shape_vec = dense_shape.vec<int64>();
|
||||||
|
*num_rows = dense_shape_vec(row_dim);
|
||||||
|
const int64 num_cols = dense_shape_vec(row_dim + 1);
|
||||||
|
OP_REQUIRES(ctx, *num_rows == num_cols,
|
||||||
|
errors::InvalidArgument("sparse matrix must be square; got: ",
|
||||||
|
*num_rows, " != ", num_cols));
|
||||||
|
const TensorShape& perm_shape = permutation_indices.shape();
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, perm_shape.dims() + 1 == rank,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"sparse matrix must have the same rank as permutation; got: ", rank,
|
||||||
|
" != ", perm_shape.dims(), " + 1."));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, perm_shape.dim_size(rank - 2) == *num_rows,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"permutation must have the same number of elements in each batch "
|
||||||
|
"as the number of rows in sparse matrix; got: ",
|
||||||
|
perm_shape.dim_size(rank - 2), " != ", *num_rows));
|
||||||
|
|
||||||
|
*batch_size = sparse_matrix.batch_size();
|
||||||
|
if (*batch_size > 1) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, perm_shape.dim_size(0) == *batch_size,
|
||||||
|
errors::InvalidArgument("permutation must have the same batch size "
|
||||||
|
"as sparse matrix; got: ",
|
||||||
|
perm_shape.dim_size(0), " != ", *batch_size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_CPU(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseCholesky") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<T>("type"), \
|
||||||
|
CSRSparseCholeskyCPUOp<T>);
|
||||||
|
REGISTER_CPU(float);
|
||||||
|
REGISTER_CPU(double);
|
||||||
|
REGISTER_CPU(complex64);
|
||||||
|
REGISTER_CPU(complex128);
|
||||||
|
|
||||||
|
#undef REGISTER_CPU
|
||||||
|
|
||||||
|
} // namespace tensorflow
651
tensorflow/core/kernels/sparse/sparse_mat_mul_op.cc
Normal file
@ -0,0 +1,651 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/SparseCore"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Swaps the dim sizes at two given dimensions of a TensorShape.
|
||||||
|
// Callers are responsible for making sure the given dimensions are within the
|
||||||
|
// valid dimension range of the TensorShape.
|
||||||
|
void SwapDimSizes(const int dim_a, const int dim_b, TensorShape* shape) {
|
||||||
|
const int64 size_a = shape->dim_size(dim_a);
|
||||||
|
const int64 size_b = shape->dim_size(dim_b);
|
||||||
|
shape->set_dim(dim_a, size_b);
|
||||||
|
shape->set_dim(dim_b, size_a);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Op to compute the matrix multiplication of two CSR Sparse Matrices.
|
||||||
|
//
|
||||||
|
// Implements a CPU kernel to perform matrix multiplication using Eigen
|
||||||
|
// SparseMatrix and its Sparse-Sparse matmul. Supports transposing and
|
||||||
|
// adjointing on the fly for both the inputs without actually constructing the
|
||||||
|
// transpose or adjoint.
|
||||||
|
//
|
||||||
|
// This implementation does not support broadcasting. Hence both the input
|
||||||
|
// CSRSparseMatrices must have the same rank. (Either rank 2 or rank 3).
|
||||||
|
//
|
||||||
|
// The output sparse matrix may contain numeric (non-structural) zeros.
|
||||||
|
// TODO(anudhyan): Consider exposing whether to prune zeros as an attribute in
|
||||||
|
// the op's interface.
|
||||||
|
//
|
||||||
|
// If multiple threads are available, we parallelize across multiple batches
|
||||||
|
// using Eigen ThreadPool. Within a single batch, we run in single threaded mode
|
||||||
|
// because Eigen's Sparse-Sparse matmul doesn't support multithreading.
|
||||||
|
//
|
||||||
|
// TODO(b/126472741): Due to the multiple batches of a 3D CSRSparseMatrix being
|
||||||
|
// laid out in contiguous memory, this implementation allocates memory to store
|
||||||
|
// a temporary copy of the matrix product. Consequently, it uses roughly twice
|
||||||
|
// the amount of memory that it needs to. This may cause a memory blowup for
|
||||||
|
// sparse matrices with a high number of non-zero elements.
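//
// As an illustrative sketch (not part of the op itself), the per-batch core of
// this kernel reduces to an Eigen RowMajor sparse-sparse product. The snippet
// below, with placeholder names, builds two tiny CSR matrices and multiplies
// them the same way a single batch is handled here:
//
//   #include <vector>
//   #include "third_party/eigen3/Eigen/SparseCore"
//
//   void ExampleSparseSparseMatMul() {
//     using SpMat = Eigen::SparseMatrix<float, Eigen::RowMajor>;
//     std::vector<Eigen::Triplet<float>> ta = {{0, 0, 1.f}, {1, 2, 2.f}};
//     std::vector<Eigen::Triplet<float>> tb = {{0, 1, 3.f}, {2, 0, 4.f}};
//     SpMat a(2, 3), b(3, 2);
//     a.setFromTriplets(ta.begin(), ta.end());
//     b.setFromTriplets(tb.begin(), tb.end());
//     SpMat c = a * b;  // Sparse-sparse product; numeric zeros are kept.
//     // c is stored in CSR form: outerIndexPtr() holds the row pointers,
//     // innerIndexPtr() the column indices, and valuePtr() the nonzero
//     // values -- exactly what gets copied into the output tensors below.
//   }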
|
||||||
|
template <typename T>
|
||||||
|
class CSRSparseMatMulCPUOp : public OpKernel {
|
||||||
|
using SparseMatrix = Eigen::SparseMatrix<T, Eigen::RowMajor>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit CSRSparseMatMulCPUOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a_));
|
||||||
|
OP_REQUIRES(c, !(adjoint_a_ && transpose_a_),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_a and transpose_a may be true."));
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b_));
|
||||||
|
OP_REQUIRES(c, !(adjoint_b_ && transpose_b_),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_b and transpose_b may be true."));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) final {
|
||||||
|
const CSRSparseMatrix* input_matrix_a;
|
||||||
|
const CSRSparseMatrix* input_matrix_b;
|
||||||
|
// TODO(anudhyan): Factor out common validation logic in CPU and GPU Ops
|
||||||
|
// into a common base class.
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix_a));
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &input_matrix_b));
|
||||||
|
OP_REQUIRES(ctx, input_matrix_a->dtype() == DataTypeToEnum<T>::value,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"dtype of a is not equal to 'type': ",
|
||||||
|
DataTypeString(input_matrix_a->dtype()), " vs. ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value)));
|
||||||
|
OP_REQUIRES(ctx, input_matrix_b->dtype() == DataTypeToEnum<T>::value,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"dtype of b is not equal to 'type': ",
|
||||||
|
DataTypeString(input_matrix_b->dtype()), " vs. ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value)));
|
||||||
|
OP_REQUIRES(ctx,
|
||||||
|
input_matrix_a->batch_size() == input_matrix_b->batch_size(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Batch sizes of A and B do not agree. Batch sizes are: ",
|
||||||
|
input_matrix_a->batch_size(), " vs. ",
|
||||||
|
input_matrix_b->batch_size()));
|
||||||
|
|
||||||
|
// Validate input_matrix_a's and input_matrix_b's shapes
|
||||||
|
TensorShape a_shape;
|
||||||
|
TensorShape b_shape;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
TensorShapeUtils::MakeShape(
|
||||||
|
input_matrix_a->dense_shape().vec<int64>(), &a_shape));
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
TensorShapeUtils::MakeShape(
|
||||||
|
input_matrix_b->dense_shape().vec<int64>(), &b_shape));
|
||||||
|
|
||||||
|
const int rank = a_shape.dims();
|
||||||
|
const int row_dim = (rank == 2) ? 0 : 1;
|
||||||
|
if (transpose_a_ || adjoint_a_)
|
||||||
|
SwapDimSizes(row_dim, row_dim + 1, &a_shape);
|
||||||
|
if (transpose_b_ || adjoint_b_)
|
||||||
|
SwapDimSizes(row_dim, row_dim + 1, &b_shape);
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, a_shape.dim_size(row_dim + 1) == b_shape.dim_size(row_dim),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Inner product dimensions of A and B do not agree. Shapes are: ",
|
||||||
|
a_shape.DebugString(), " vs. ", b_shape.DebugString()));
|
||||||
|
|
||||||
|
// Infer the output shape of the matrix product.
|
||||||
|
// TODO(ebrevdo): MatMul support for broadcasting at least in the
|
||||||
|
// batch dimension.
|
||||||
|
const int batch_size = input_matrix_a->batch_size();
|
||||||
|
Tensor output_shape(cpu_allocator(), DT_INT64, TensorShape({rank}));
|
||||||
|
auto output_shape_vec = output_shape.vec<int64>();
|
||||||
|
if (rank == 3) output_shape_vec(0) = batch_size;
|
||||||
|
output_shape_vec(row_dim) = a_shape.dim_size(row_dim);
|
||||||
|
output_shape_vec(row_dim + 1) = b_shape.dim_size(row_dim + 1);
|
||||||
|
|
||||||
|
// Set batch pointers.
|
||||||
|
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
|
||||||
|
auto batch_ptr_vec = batch_ptr.vec<int32>();
|
||||||
|
batch_ptr_vec(0) = 0;
|
||||||
|
|
||||||
|
// Store intermediate matrix products for each batch.
|
||||||
|
// TODO(b/126472741): For a single batch, consider reusing the
|
||||||
|
// SparseMatrices' buffers to construct the CSRSparseMatrix to prevent 2x
|
||||||
|
// memory usage.
|
||||||
|
std::vector<SparseMatrix> output_matrices(batch_size);
|
||||||
|
|
||||||
|
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||||
|
// Estimate the cost per batch as num_output_rows times the product of the
|
||||||
|
// average numbers of nonzeros per row of A and B.
|
||||||
|
const int64 num_output_rows = output_shape_vec(row_dim);
|
||||||
|
const double avg_nnz_per_row_a =
|
||||||
|
input_matrix_a->total_nnz() /
|
||||||
|
static_cast<double>(a_shape.dim_size(row_dim) * batch_size);
|
||||||
|
const double avg_nnz_per_row_b =
|
||||||
|
input_matrix_b->total_nnz() /
|
||||||
|
static_cast<double>(b_shape.dim_size(row_dim) * batch_size);
|
||||||
|
const int64 matmul_cost_per_batch =
|
||||||
|
num_output_rows * (avg_nnz_per_row_a * avg_nnz_per_row_b);
|
||||||
|
|
||||||
|
// Parallelize matrix multiplication across batches.
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||||
|
matmul_cost_per_batch, [&](int64 batch_begin, int64 batch_end) {
|
||||||
|
for (int64 batch_idx = batch_begin; batch_idx < batch_end;
|
||||||
|
++batch_idx) {
|
||||||
|
// For each batch, map the CSRSparseMatrix as Eigen SparseMatrix
|
||||||
|
// without copying the underlying data.
|
||||||
|
auto a_ref = GetSparseMatrixRef(*input_matrix_a, rank, batch_idx,
|
||||||
|
transpose_a_, adjoint_a_);
|
||||||
|
auto b_ref = GetSparseMatrixRef(*input_matrix_b, rank, batch_idx,
|
||||||
|
transpose_b_, adjoint_b_);
|
||||||
|
|
||||||
|
// Matrix multiply while *not* pruning numerical zeros on the fly.
|
||||||
|
// Allocates output SparseMatrix and moves it to our list of
|
||||||
|
// output_matrices.
|
||||||
|
output_matrices[batch_idx] = a_ref * b_ref;
|
||||||
|
|
||||||
|
// For now, batch_ptr contains the number of nonzeros in each
|
||||||
|
// batch.
|
||||||
|
batch_ptr_vec(batch_idx + 1) =
|
||||||
|
output_matrices[batch_idx].nonZeros();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Compute the cumulative sum to obtain the batch pointers.
|
||||||
|
std::partial_sum(batch_ptr_vec.data(),
|
||||||
|
batch_ptr_vec.data() + batch_size + 1,
|
||||||
|
batch_ptr_vec.data());
|
||||||
|
const int64 total_nnz = batch_ptr_vec(batch_size);
|
||||||
|
|
||||||
|
// Allocate output tensors.
|
||||||
|
Tensor output_row_ptr(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({(num_output_rows + 1) * batch_size}));
|
||||||
|
Tensor output_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
|
||||||
|
Tensor output_values(cpu_allocator(), DataTypeToEnum<T>::value,
|
||||||
|
TensorShape({total_nnz}));
|
||||||
|
auto output_row_ptr_ptr = output_row_ptr.flat<int32>().data();
|
||||||
|
auto output_col_ind_ptr = output_col_ind.flat<int32>().data();
|
||||||
|
auto output_values_ptr = output_values.flat<T>().data();
|
||||||
|
|
||||||
|
// Copy the output matrices from each batch into the CSRSparseMatrix
|
||||||
|
// tensors.
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||||
|
(3 * total_nnz) / batch_size /* cost per unit */,
|
||||||
|
[&](int64 batch_begin, int64 batch_end) {
|
||||||
|
for (int64 batch_idx = batch_begin; batch_idx < batch_end;
|
||||||
|
++batch_idx) {
|
||||||
|
const SparseMatrix& output_matrix = output_matrices[batch_idx];
|
||||||
|
const int64 nnz = output_matrix.nonZeros();
|
||||||
|
std::copy(output_matrix.outerIndexPtr(),
|
||||||
|
output_matrix.outerIndexPtr() + num_output_rows + 1,
|
||||||
|
output_row_ptr_ptr + batch_idx * (num_output_rows + 1));
|
||||||
|
std::copy(output_matrix.innerIndexPtr(),
|
||||||
|
output_matrix.innerIndexPtr() + nnz,
|
||||||
|
output_col_ind_ptr + batch_ptr_vec(batch_idx));
|
||||||
|
std::copy(output_matrix.valuePtr(),
|
||||||
|
output_matrix.valuePtr() + nnz,
|
||||||
|
output_values_ptr + batch_ptr_vec(batch_idx));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create the CSRSparseMatrix object from its component Tensors and prepare
|
||||||
|
// the Variant output Tensor.
|
||||||
|
CSRSparseMatrix output_csr_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
DataTypeToEnum<T>::value, output_shape, batch_ptr,
|
||||||
|
output_row_ptr, output_col_ind, output_values,
|
||||||
|
&output_csr_matrix));
|
||||||
|
Tensor* output_csr_matrix_tensor;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
|
||||||
|
cpu_alloc));
|
||||||
|
output_csr_matrix_tensor->scalar<Variant>()() =
|
||||||
|
std::move(output_csr_matrix);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Returns an Eigen::Ref expression of a SparseMatrix that points to the
|
||||||
|
// underlying memory of the given CSRSparseMatrix.
|
||||||
|
Eigen::Ref<const SparseMatrix> GetSparseMatrixRef(
|
||||||
|
const CSRSparseMatrix& csr_matrix, const int rank, const int batch_index,
|
||||||
|
const bool transpose, const bool adjoint) {
|
||||||
|
const auto dense_shape = csr_matrix.dense_shape().vec<int64>();
|
||||||
|
const int64 num_rows = dense_shape(rank == 2 ? 0 : 1);
|
||||||
|
const int64 num_cols = dense_shape(rank == 2 ? 1 : 2);
|
||||||
|
|
||||||
|
Eigen::Map<const SparseMatrix> sparse_matrix(
|
||||||
|
num_rows, num_cols, csr_matrix.nnz(batch_index),
|
||||||
|
csr_matrix.row_pointers_vec(batch_index).data(),
|
||||||
|
csr_matrix.col_indices_vec(batch_index).data(),
|
||||||
|
csr_matrix.values_vec<T>(batch_index).data());
|
||||||
|
|
||||||
|
// The transpose/adjoint expressions are not actually evaluated until
|
||||||
|
// necessary. Hence we don't create copies or modify the input matrix
|
||||||
|
// in place.
|
||||||
|
if (transpose) return sparse_matrix.transpose();
|
||||||
|
if (adjoint) return sparse_matrix.adjoint();
|
||||||
|
return sparse_matrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool transpose_a_;
|
||||||
|
bool transpose_b_;
|
||||||
|
bool adjoint_a_;
|
||||||
|
bool adjoint_b_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class CSRSparseMatMulGPUOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit CSRSparseMatMulGPUOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_));
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_));
|
||||||
|
bool adjoint_a;
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a));
|
||||||
|
OP_REQUIRES(c, !(adjoint_a && transpose_a_),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_a and transpose_a may be true."));
|
||||||
|
bool adjoint_b;
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b));
|
||||||
|
OP_REQUIRES(c, !(adjoint_b && transpose_b_),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_b and transpose_b may be true."));
|
||||||
|
conjugate_a_ = adjoint_a;
|
||||||
|
conjugate_b_ = adjoint_b;
|
||||||
|
transpose_a_ = transpose_a_ || adjoint_a;
|
||||||
|
transpose_b_ = transpose_b_ || adjoint_b;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) final {
|
||||||
|
const CSRSparseMatrix* a_matrix;
|
||||||
|
const CSRSparseMatrix* b_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix));
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &b_matrix));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, a_matrix->dtype() == DataTypeToEnum<T>::value,
|
||||||
|
errors::InvalidArgument("dtype of a is not equal to 'type': ",
|
||||||
|
DataTypeString(a_matrix->dtype()), " vs. ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value)));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, b_matrix->dtype() == DataTypeToEnum<T>::value,
|
||||||
|
errors::InvalidArgument("dtype of b is not equal to 'type': ",
|
||||||
|
DataTypeString(b_matrix->dtype()), " vs. ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value)));
|
||||||
|
|
||||||
|
// TODO(ebrevdo): MatMul support for broadcasting at least in the
|
||||||
|
// batch dimension.
|
||||||
|
auto a_dense_shape = a_matrix->dense_shape().vec<int64>();
|
||||||
|
auto b_dense_shape = b_matrix->dense_shape().vec<int64>();
|
||||||
|
|
||||||
|
TensorShape a_tensor_shape;
|
||||||
|
TensorShape b_tensor_shape;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
TensorShapeUtils::MakeShape(a_dense_shape, &a_tensor_shape));
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
TensorShapeUtils::MakeShape(b_dense_shape, &b_tensor_shape));
|
||||||
|
|
||||||
|
const int rank = a_tensor_shape.dims();
|
||||||
|
const int row_dim = (rank == 2) ? 0 : 1;
|
||||||
|
|
||||||
|
const int64 a_inner_dim =
|
||||||
|
a_tensor_shape.dim_size(transpose_a_ ? row_dim : row_dim + 1);
|
||||||
|
const int64 b_inner_dim =
|
||||||
|
b_tensor_shape.dim_size(transpose_b_ ? row_dim + 1 : row_dim);
|
||||||
|
|
||||||
|
const int batch_size = a_matrix->batch_size();
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, a_inner_dim == b_inner_dim,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Inner product dimensions of A and B do not agree. Shapes are: ",
|
||||||
|
a_tensor_shape.DebugString(), " vs. ",
|
||||||
|
b_tensor_shape.DebugString()));
|
||||||
|
|
||||||
|
Tensor c_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
|
||||||
|
auto c_dense_shape = c_dense_shape_t.vec<int64>();
|
||||||
|
|
||||||
|
if (rank == 3) c_dense_shape(0) = batch_size;
|
||||||
|
c_dense_shape(row_dim) =
|
||||||
|
a_tensor_shape.dim_size(transpose_a_ ? row_dim + 1 : row_dim);
|
||||||
|
c_dense_shape(row_dim + 1) =
|
||||||
|
b_tensor_shape.dim_size(transpose_b_ ? row_dim : row_dim + 1);
|
||||||
|
|
||||||
|
const int64 rows = c_dense_shape((rank == 2) ? 0 : 1);
|
||||||
|
|
||||||
|
CSRSparseMatrix c;
|
||||||
|
Tensor c_row_ptrs;
|
||||||
|
Tensor c_col_inds;
|
||||||
|
Tensor c_values;
|
||||||
|
|
||||||
|
// TODO(ebrevdo): Re-enable transposing within the GEMM kernel when cuSparse
|
||||||
|
// stops spitting out CUSPARSE_STATUS_INTERNAL_ERROR values for transposes.
|
||||||
|
functor::CSRSparseSparseMatrixMatMul<Device, T> csr_gemm(
|
||||||
|
ctx, /*transpose_a=*/false, /*adjoint_a=*/false, /*transpose_b=*/false);
|
||||||
|
OP_REQUIRES_OK(ctx, csr_gemm.Initialize());
|
||||||
|
|
||||||
|
Tensor c_batch_ptr_t(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({batch_size + 1}));
|
||||||
|
auto c_batch_ptr = c_batch_ptr_t.vec<int32>();
|
||||||
|
c_batch_ptr(0) = 0;
|
||||||
|
|
||||||
|
Tensor c_row_ptr_t;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
|
||||||
|
DT_INT32, TensorShape({batch_size * (rows + 1)}),
|
||||||
|
&c_row_ptr_t));
|
||||||
|
auto c_row_ptr = c_row_ptr_t.vec<int32>();
|
||||||
|
|
||||||
|
// Possibly transpose a.
|
||||||
|
const CSRSparseMatrix* a_input_matrix;
|
||||||
|
// If we need to transpose a, we will store the result temporarily
|
||||||
|
// in the object below.
|
||||||
|
CSRSparseMatrix a_matrix_transposed;
|
||||||
|
if (!transpose_a_) {
|
||||||
|
a_input_matrix = a_matrix;
|
||||||
|
} else {
|
||||||
|
functor::CSRSparseMatrixTranspose<Device, T> transpose;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, transpose(ctx, conjugate_a_, *a_matrix, &a_matrix_transposed));
|
||||||
|
a_input_matrix = &a_matrix_transposed;
|
||||||
|
}
|
||||||
|
auto a_input_dense_shape = a_input_matrix->dense_shape().vec<int64>();
|
||||||
|
|
||||||
|
// Possibly transpose b.
|
||||||
|
const CSRSparseMatrix* b_input_matrix;
|
||||||
|
// If we need to transpose b, we will store the result temporarily
|
||||||
|
// in the object below.
|
||||||
|
CSRSparseMatrix b_matrix_transposed;
|
||||||
|
if (!transpose_b_) {
|
||||||
|
b_input_matrix = b_matrix;
|
||||||
|
} else {
|
||||||
|
functor::CSRSparseMatrixTranspose<Device, T> transpose;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, transpose(ctx, conjugate_b_, *b_matrix, &b_matrix_transposed));
|
||||||
|
b_input_matrix = &b_matrix_transposed;
|
||||||
|
}
|
||||||
|
auto b_input_dense_shape = b_input_matrix->dense_shape().vec<int64>();
|
||||||
|
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
// Calculate output sizes for all minibatch entries.
|
||||||
|
// Store in c_batch_ptr and update c_row_ptrs.
|
||||||
|
ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
|
||||||
|
a_input_matrix->col_indices_vec(i),
|
||||||
|
a_input_matrix->values_vec<T>(i),
|
||||||
|
a_input_dense_shape};
|
||||||
|
ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
|
||||||
|
b_input_matrix->col_indices_vec(i),
|
||||||
|
b_input_matrix->values_vec<T>(i),
|
||||||
|
b_input_dense_shape};
|
||||||
|
|
||||||
|
TTypes<int32>::UnalignedVec c_row_ptr_i(&c_row_ptr(i * (rows + 1)),
|
||||||
|
rows + 1);
|
||||||
|
|
||||||
|
int c_nnz_i;
|
||||||
|
OP_REQUIRES_OK(ctx, csr_gemm.GetOutputStructure(a_comp, b_comp,
|
||||||
|
c_row_ptr_i, &c_nnz_i));
|
||||||
|
c_batch_ptr(i + 1) = c_batch_ptr(i) + c_nnz_i;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor c_col_ind_t;
|
||||||
|
Tensor c_values_t;
|
||||||
|
|
||||||
|
const int total_nnz = c_batch_ptr(batch_size);
|
||||||
|
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
|
||||||
|
&c_col_ind_t));
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->allocate_temp(DataTypeToEnum<T>::value,
|
||||||
|
TensorShape({total_nnz}), &c_values_t));
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
DataTypeToEnum<T>::value, c_dense_shape_t, c_batch_ptr_t,
|
||||||
|
c_row_ptr_t, c_col_ind_t, c_values_t, &c));
|
||||||
|
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
ConstCSRComponent<T> a_comp{a_input_matrix->row_pointers_vec(i),
|
||||||
|
a_input_matrix->col_indices_vec(i),
|
||||||
|
a_input_matrix->values_vec<T>(i),
|
||||||
|
a_input_dense_shape};
|
||||||
|
ConstCSRComponent<T> b_comp{b_input_matrix->row_pointers_vec(i),
|
||||||
|
b_input_matrix->col_indices_vec(i),
|
||||||
|
b_input_matrix->values_vec<T>(i),
|
||||||
|
b_input_dense_shape};
|
||||||
|
CSRComponent<T> c_comp{c.row_pointers_vec(i), c.col_indices_vec(i),
|
||||||
|
c.values_vec<T>(i), c_dense_shape};
|
||||||
|
OP_REQUIRES_OK(ctx, csr_gemm.Compute(a_comp, b_comp, &c_comp));
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor c_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
|
||||||
|
c_t.scalar<Variant>()() = std::move(c);
|
||||||
|
ctx->set_output(0, c_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool transpose_a_;
|
||||||
|
bool transpose_b_;
|
||||||
|
bool conjugate_a_;
|
||||||
|
bool conjugate_b_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_CPU(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseMatMul") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<T>("type"), \
|
||||||
|
CSRSparseMatMulCPUOp<T>);
|
||||||
|
|
||||||
|
REGISTER_CPU(float)
|
||||||
|
REGISTER_CPU(double)
|
||||||
|
REGISTER_CPU(complex64)
|
||||||
|
REGISTER_CPU(complex128)
|
||||||
|
|
||||||
|
#undef REGISTER_CPU
|
||||||
|
|
||||||
|
#define REGISTER(DEV, T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseMatrixSparseMatMul") \
|
||||||
|
.Device(DEVICE_##DEV) \
|
||||||
|
.TypeConstraint<T>("type"), \
|
||||||
|
CSRSparseMatMulGPUOp<DEV##Device, T>);
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define REGISTER_GPU(T) REGISTER(GPU, T)
|
||||||
|
|
||||||
|
REGISTER_GPU(float)
|
||||||
|
REGISTER_GPU(double)
|
||||||
|
REGISTER_GPU(complex64)
|
||||||
|
REGISTER_GPU(complex128)
|
||||||
|
|
||||||
|
#undef REGISTER_GPU
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
namespace functor {
|
||||||
|
template <typename T>
|
||||||
|
struct CSRSparseSparseMatrixMatMul<GPUDevice, T>
|
||||||
|
: public CSRStructureModifyingFunctor<GPUDevice, T> {
|
||||||
|
explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a,
|
||||||
|
bool adjoint_a, bool transpose_b)
|
||||||
|
: ctx_(ctx),
|
||||||
|
cuda_sparse_(ctx),
|
||||||
|
initialized_(false),
|
||||||
|
transpose_a_(transpose_a),
|
||||||
|
adjoint_a_(adjoint_a),
|
||||||
|
transpose_b_(transpose_b) {
|
||||||
|
// TODO(ebrevdo): Figure out why transposed implementations crash cuSparse.
|
||||||
|
transA_ = transpose_a ? (adjoint_a ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE
                                   : CUSPARSE_OPERATION_TRANSPOSE)
                      : CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||||
|
transB_ = transpose_b ? CUSPARSE_OPERATION_TRANSPOSE
|
||||||
|
: CUSPARSE_OPERATION_NON_TRANSPOSE;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Initialize() {
|
||||||
|
if (adjoint_a_ && transpose_a_) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_a and transpose_a may be true.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse_.Initialize());
|
||||||
|
TF_RETURN_IF_ERROR(descrA_.Initialize());
|
||||||
|
TF_RETURN_IF_ERROR(descrB_.Initialize());
|
||||||
|
TF_RETURN_IF_ERROR(descrC_.Initialize());
|
||||||
|
initialized_ = true;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetOutputStructure(const ConstCSRComponent<T>& a,
|
||||||
|
const ConstCSRComponent<T>& b,
|
||||||
|
TTypes<int32>::UnalignedVec c_row_ptr,
|
||||||
|
int* output_nnz) {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
|
||||||
|
const int m =
|
||||||
|
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
|
||||||
|
if (!transpose_a_) {
|
||||||
|
DCHECK_EQ(m, a.row_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
DCHECK_EQ(m, c_row_ptr.size() - 1);
|
||||||
|
const int k =
|
||||||
|
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
|
||||||
|
if (!transpose_b_) {
|
||||||
|
DCHECK_EQ(k, b.row_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
const int nnzA = a.col_ind.size();
|
||||||
|
const int nnzB = b.col_ind.size();
|
||||||
|
|
||||||
|
const int n =
|
||||||
|
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
|
||||||
|
|
||||||
|
*output_nnz = -1;
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse_.CsrgemmNnz(
|
||||||
|
transA_, transB_, m, n, k, descrA_.descr(), nnzA, a.row_ptr.data(),
|
||||||
|
a.col_ind.data(), descrB_.descr(), nnzB, b.row_ptr.data(),
|
||||||
|
b.col_ind.data(), descrC_.descr(), c_row_ptr.data(), output_nnz));
|
||||||
|
|
||||||
|
if (*output_nnz < 0) {
|
||||||
|
return errors::Internal(
|
||||||
|
"CSRMatMul: CsrgemmNnz returned nnzTotalDevHostPtr < 0: ",
|
||||||
|
*output_nnz);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Compute(const ConstCSRComponent<T>& a, const ConstCSRComponent<T>& b,
|
||||||
|
CSRComponent<T>* c) {
|
||||||
|
DCHECK(initialized_);
|
||||||
|
|
||||||
|
const int m =
|
||||||
|
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 1 : 2));
|
||||||
|
if (!transpose_a_) {
|
||||||
|
DCHECK_EQ(m, a.row_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
DCHECK_EQ(m, c->dense_shape_host(c->dense_shape_host.size() - 2));
|
||||||
|
DCHECK_EQ(m, c->row_ptr.size() - 1);
|
||||||
|
const int k =
|
||||||
|
a.dense_shape_host(a.dense_shape_host.size() - (transpose_a_ ? 2 : 1));
|
||||||
|
if (!transpose_b_) {
|
||||||
|
DCHECK_EQ(k, b.row_ptr.size() - 1);
|
||||||
|
}
|
||||||
|
const int nnzA = a.col_ind.size();
|
||||||
|
const int nnzB = b.col_ind.size();
|
||||||
|
|
||||||
|
const int n =
|
||||||
|
b.dense_shape_host(b.dense_shape_host.size() - (transpose_b_ ? 2 : 1));
|
||||||
|
DCHECK_EQ(n, c->dense_shape_host(c->dense_shape_host.size() - 1));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse_.Csrgemm(
|
||||||
|
transA_, transB_, m, k, n, descrA_.descr(), nnzA, a.values.data(),
|
||||||
|
a.row_ptr.data(), a.col_ind.data(), descrB_.descr(), nnzB,
|
||||||
|
b.values.data(), b.row_ptr.data(), b.col_ind.data(), descrC_.descr(),
|
||||||
|
c->values.data(), c->row_ptr.data(), c->col_ind.data()));
|
||||||
|
|
||||||
|
// TODO(ebrevdo): Add a flag to CSRSparseMatrix whether matrix
|
||||||
|
// columns are sorted? Above operation leads to unsorted columns.
|
||||||
|
// For now, here is an example of how to ensure the output columns
|
||||||
|
// are sorted. Leaving here in case we realize we need to ensure
|
||||||
|
// sorted columns in the future.
|
||||||
|
//
|
||||||
|
// TF_RETURN_IF_ERROR(cuda_sparse.Csru2csr(
|
||||||
|
// m, n, nnzTotalDevHostPtr, descrA_.descr(), c->values.data(),
|
||||||
|
// c->row_ptr.data(), c->col_ind.data()));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OpKernelContext* ctx_;
|
||||||
|
CudaSparse cuda_sparse_;
|
||||||
|
bool initialized_;
|
||||||
|
bool transpose_a_;
|
||||||
|
bool adjoint_a_;
|
||||||
|
bool transpose_b_;
|
||||||
|
CudaSparseMatrixDescriptor descrA_;
|
||||||
|
CudaSparseMatrixDescriptor descrB_;
|
||||||
|
CudaSparseMatrixDescriptor descrC_;
|
||||||
|
cusparseOperation_t transA_;
|
||||||
|
cusparseOperation_t transB_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
} // namespace tensorflow
43
tensorflow/core/kernels/sparse/sparse_matrix.cc
Normal file
@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
constexpr const char CSRSparseMatrix::kTypeName[];
|
||||||
|
|
||||||
|
// Register variant decoding function for TF's RPC.
|
||||||
|
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CSRSparseMatrix,
|
||||||
|
CSRSparseMatrix::kTypeName);
|
||||||
|
|
||||||
|
#define REGISTER_CSR_COPY(DIRECTION) \
|
||||||
|
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
|
||||||
|
CSRSparseMatrix, DIRECTION, CSRSparseMatrix::DeviceCopy)
|
||||||
|
|
||||||
|
REGISTER_CSR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
|
||||||
|
REGISTER_CSR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
|
||||||
|
REGISTER_CSR_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
|
||||||
|
|
||||||
|
#undef REGISTER_CSR_COPY
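
// Illustrative sketch (not part of this file): a CSRSparseMatrix travels
// between ops inside a scalar DT_VARIANT Tensor, which is what the copy and
// decode registrations above make usable across devices and RPC. Assuming a
// previously constructed `csr_matrix`, a minimal round trip looks like:
//
//   Tensor variant_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
//   variant_t.scalar<Variant>()() = std::move(csr_matrix);
//   ...
//   const CSRSparseMatrix* read_back =
//       variant_t.scalar<Variant>()().get<CSRSparseMatrix>();
//   // read_back is nullptr if the Variant does not hold a CSRSparseMatrix.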
|
||||||
|
|
||||||
|
} // namespace tensorflow
640
tensorflow/core/kernels/sparse/sparse_matrix.h
Normal file
@ -0,0 +1,640 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant.h"
|
||||||
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class CSRSparseMatrix {
|
||||||
|
// CreateCSRSparseMatrix is the main method used to construct a
|
||||||
|
// CSRSparseMatrix. The representations for both 2D and 3D
|
||||||
|
// (batched) CSR Sparse Matrices are the same:
|
||||||
|
//
|
||||||
|
// dtype: The datatype of the values.
|
||||||
|
// dense_shape: The dense shape of the matrix.
|
||||||
|
// * Host int64 vector, size 2 or 3.
|
||||||
|
// * Takes on values: (rows, cols) or (batch_size, rows, cols).
|
||||||
|
// batch_pointers: Batch offset pointers into col_indices and values.
|
||||||
|
// * Host int32 vector, size (batch_size + 1).
|
||||||
|
// * Takes on values: (0, nnz[0], nnz[0] + nnz[1], ..., total_nnz).
|
||||||
|
// row_pointers: Row offset pointers into col_indices and values.
|
||||||
|
// * Device int32 vector, size ((rows + 1) * batch_size).
|
||||||
|
// * Each block of size (rows + 1) takes on values:
|
||||||
|
// (0, num_rows{b}[0], num_rows{b}[0] + num_rows{b}[1], ..., nnz[b]).
|
||||||
|
// for b = 0 .. batch_size - 1.
|
||||||
|
// col_indices: Column values for the given row and column index.
|
||||||
|
// * Device int32 vector, size total_nnz.
|
||||||
|
// values: Actual values for the given row and column index.
|
||||||
|
// * Device dtype vector, size total_nnz.
|
||||||
|
//
|
||||||
|
// The storage agreement is such that for a given (batch, row, ix):
|
||||||
|
// offset = batch_pointers(batch) + row_pointers(batch * (rows + 1) + row)
|
||||||
|
// col = col_indices(offset + ix)
|
||||||
|
// val = values(offset + ix)
|
||||||
|
// where ix < #nnz columns in (batch, row).
|
||||||
|
// Then:
|
||||||
|
// matrix(batch, row, col) = val.
|
||||||
|
//
|
||||||
|
// All other elements in the dense representation are treated as 0 / empty.
|
||||||
|
//
|
||||||
|
// For example, for a 2D sparse matrix m shaped (3, 4) such that:
|
||||||
|
//
|
||||||
|
// m[0, 0] = 1.0
|
||||||
|
// m[0, 1] = 2.0
|
||||||
|
// m[0, 2] = 3.0
|
||||||
|
// m[2, 2] = 4.0
|
||||||
|
// m[2, 3] = 5.0
|
||||||
|
//
|
||||||
|
// The corresponding representation is:
|
||||||
|
//
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// dense_shape: (3, 4)
|
||||||
|
// batch_pointers: (0, 5)
|
||||||
|
// row_pointers: (0, 3, 3, 5)
|
||||||
|
// col_indices: concat((0, 1, 2), (), (2, 3))
|
||||||
|
// values: concat((1.0, 2.0, 3.0), (), (4.0, 5.0))
|
||||||
|
//
|
||||||
|
// For a 3D sparse matrix m shaped (2, 3, 4) such that:
|
||||||
|
//
|
||||||
|
// m[0, 0, 0] = 1.0
|
||||||
|
// m[0, 0, 2] = 2.0
|
||||||
|
// m[0, 2, 3] = 3.0
|
||||||
|
// m[1, 0, 3] = 4.0
|
||||||
|
// m[1, 1, 0] = 5.0
|
||||||
|
//
|
||||||
|
// The corresponding representation is:
|
||||||
|
// dtype: DT_FLOAT
|
||||||
|
// dense_shape: (2, 3, 4)
|
||||||
|
// batch_pointers: (0, 3, 5)
|
||||||
|
// row_pointers: concat((0, 2, 2, 3), (0, 1, 2, 2))
|
||||||
|
// col_indices: concat(concat((0, 2), (), (3,)),
|
||||||
|
// concat((3,), (), (0,)))
|
||||||
|
// values: concat(concat((1.0, 2.0), (3.0,), ()),
|
||||||
|
//                 concat((4.0,), (5.0,), ()))
|
||||||
|
//
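//
// As a worked illustration of the storage agreement, reading m[0, 2, 3] from
// the 3D example above (rows = 3) proceeds as:
//   row_pointers(0 * (rows + 1) + 2) = row_pointers(2) = 2
//   offset = batch_pointers(0) + 2 = 2
//   row 2 of batch 0 holds row_pointers(3) - row_pointers(2) = 1 nonzero
//   col_indices(offset + 0) = 3 and values(offset + 0) = 3.0
// giving matrix(0, 2, 3) = 3.0.
//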
|
||||||
|
public:
|
||||||
|
static constexpr const char kTypeName[] = "tensorflow::CSRSparseMatrix";
|
||||||
|
|
||||||
|
CSRSparseMatrix() : metadata_{false, DT_INVALID} {}
|
||||||
|
|
||||||
|
CSRSparseMatrix(const CSRSparseMatrix& rhs)
|
||||||
|
: metadata_(rhs.metadata_),
|
||||||
|
dense_shape_(rhs.dense_shape_),
|
||||||
|
batch_pointers_(rhs.batch_pointers_),
|
||||||
|
row_pointers_(rhs.row_pointers_),
|
||||||
|
col_indices_(rhs.col_indices_),
|
||||||
|
values_(rhs.values_) {
|
||||||
|
SetupVecs();
|
||||||
|
}
|
||||||
|
|
||||||
|
CSRSparseMatrix(CSRSparseMatrix&& rhs)
|
||||||
|
: metadata_(rhs.metadata_),
|
||||||
|
dense_shape_(std::move(rhs.dense_shape_)),
|
||||||
|
batch_pointers_(std::move(rhs.batch_pointers_)),
|
||||||
|
row_pointers_(std::move(rhs.row_pointers_)),
|
||||||
|
col_indices_(std::move(rhs.col_indices_)),
|
||||||
|
values_(std::move(rhs.values_)) {
|
||||||
|
SetupVecs();
|
||||||
|
rhs.metadata_.validated = false;
|
||||||
|
rhs.metadata_.dtype = DT_INVALID;
|
||||||
|
rhs.ClearVecs();
|
||||||
|
}
|
||||||
|
|
||||||
|
CSRSparseMatrix& operator=(CSRSparseMatrix&& rhs) {
|
||||||
|
if (this == &rhs) return *this;
|
||||||
|
metadata_ = rhs.metadata_;
|
||||||
|
metadata_.validated = rhs.metadata_.validated;
|
||||||
|
dense_shape_ = std::move(rhs.dense_shape_);
|
||||||
|
batch_pointers_ = std::move(rhs.batch_pointers_);
|
||||||
|
row_pointers_ = std::move(rhs.row_pointers_);
|
||||||
|
col_indices_ = std::move(rhs.col_indices_);
|
||||||
|
values_ = std::move(rhs.values_);
|
||||||
|
SetupVecs();
|
||||||
|
rhs.metadata_ = {false, DT_INVALID};
|
||||||
|
rhs.ClearVecs();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Status CreateCSRSparseMatrix(DataType dtype,
|
||||||
|
const Tensor& dense_shape, // on host
|
||||||
|
const Tensor& batch_pointers, // on host
|
||||||
|
const Tensor& row_pointers,
|
||||||
|
const Tensor& col_indices,
|
||||||
|
const Tensor& values,
|
||||||
|
CSRSparseMatrix* matrix) {
|
||||||
|
*matrix = CSRSparseMatrix(dtype, dense_shape, batch_pointers, row_pointers,
|
||||||
|
col_indices, values);
|
||||||
|
Status s = matrix->Validate();
|
||||||
|
matrix->metadata_.validated = s.ok();
|
||||||
|
matrix->SetupVecs();
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Validate() const {
|
||||||
|
return ValidateTypesAndShapes(metadata_.dtype, dense_shape_,
|
||||||
|
batch_pointers_, row_pointers_, col_indices_,
|
||||||
|
values_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Clear() {
|
||||||
|
metadata_ = {false, DT_INVALID};
|
||||||
|
dense_shape_ = Tensor();
|
||||||
|
batch_pointers_ = Tensor();
|
||||||
|
row_pointers_ = Tensor();
|
||||||
|
col_indices_ = Tensor();
|
||||||
|
values_ = Tensor();
|
||||||
|
ClearVecs();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool valid() const {
|
||||||
|
return metadata_.validated && dense_shape_.IsInitialized() &&
|
||||||
|
batch_pointers_.IsInitialized() && row_pointers_.IsInitialized() &&
|
||||||
|
col_indices_.IsInitialized() && values_.IsInitialized() &&
|
||||||
|
dense_shape_.NumElements() > 1 &&
|
||||||
|
batch_pointers_.NumElements() > 0 && row_pointers_.NumElements() > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataType dtype() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return metadata_.dtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int dims() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return dense_shape_.NumElements();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int nnz(int batch) const {
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
return (*batch_pointers_vec_)(batch + 1) - (*batch_pointers_vec_)(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int batch_offset(int batch) const {
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
return (*batch_pointers_vec_)(batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int total_nnz() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return (*batch_pointers_vec_)(batch_size());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Tensor& dense_shape() {
|
||||||
|
DCHECK(valid());
|
||||||
|
return dense_shape_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const Tensor& dense_shape() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return dense_shape_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TTypes<int32>::UnalignedVec row_pointers_vec(int batch) {
|
||||||
|
DCHECK(valid());
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
const int64 rows = dense_shape().vec<int64>()((dims() == 2) ? 0 : 1);
|
||||||
|
const int offset = batch * (rows + 1);
|
||||||
|
return TTypes<int32>::UnalignedVec(row_pointers_vec_->data() + offset,
|
||||||
|
rows + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TTypes<int32>::UnalignedConstVec row_pointers_vec(int batch) const {
|
||||||
|
DCHECK(valid());
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
const int64 rows = dense_shape().vec<int64>()((dims() == 2) ? 0 : 1);
|
||||||
|
const int offset = batch * (rows + 1);
|
||||||
|
return TTypes<int32>::UnalignedConstVec(row_pointers_vec_->data() + offset,
|
||||||
|
rows + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TTypes<int32>::UnalignedVec col_indices_vec(int batch) {
|
||||||
|
DCHECK(valid());
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
const int offset = (*batch_pointers_vec_)(batch);
|
||||||
|
const int nnz_in_batch = nnz(batch);
|
||||||
|
return TTypes<int32>::UnalignedVec(col_indices_vec_->data() + offset,
|
||||||
|
nnz_in_batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline TTypes<int32>::UnalignedConstVec col_indices_vec(int batch) const {
|
||||||
|
DCHECK(valid());
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
const int offset = (*batch_pointers_vec_)(batch);
|
||||||
|
const int nnz_in_batch = nnz(batch);
|
||||||
|
return TTypes<int32>::UnalignedConstVec(col_indices_vec_->data() + offset,
|
||||||
|
nnz_in_batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline typename TTypes<T>::UnalignedVec values_vec(int batch) {
|
||||||
|
DCHECK(valid());
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
const int offset = (*batch_pointers_vec_)(batch);
|
||||||
|
const int nnz_in_batch = nnz(batch);
|
||||||
|
return typename TTypes<T>::UnalignedVec(&(values().vec<T>()(offset)),
|
||||||
|
nnz_in_batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline typename TTypes<T>::UnalignedConstVec values_vec(int batch) const {
|
||||||
|
DCHECK(valid());
|
||||||
|
DCHECK_LT(batch, batch_size());
|
||||||
|
const int offset = (*batch_pointers_vec_)(batch);
|
||||||
|
const int nnz_in_batch = nnz(batch);
|
||||||
|
return typename TTypes<T>::UnalignedConstVec(&(values().vec<T>()(offset)),
|
||||||
|
nnz_in_batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Tensor& row_pointers() {
|
||||||
|
DCHECK(valid());
|
||||||
|
return row_pointers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const Tensor& row_pointers() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return row_pointers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Tensor& col_indices() {
|
||||||
|
DCHECK(valid());
|
||||||
|
return col_indices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const Tensor& col_indices() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return col_indices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Tensor& values() {
|
||||||
|
DCHECK(valid());
|
||||||
|
return values_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const Tensor& values() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return values_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Tensor& batch_pointers() {
|
||||||
|
DCHECK(valid());
|
||||||
|
return batch_pointers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline const Tensor& batch_pointers() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return batch_pointers_;
|
||||||
|
}
|
||||||
|
|
||||||
|
string TypeName() const { return kTypeName; }
|
||||||
|
|
||||||
|
// TODO(ebrevdo): A better debug string.
|
||||||
|
string DebugString() const { return dense_shape_.DebugString(); }
|
||||||
|
|
||||||
|
// Returns the number of batch elements. This is equal to 1 if the
|
||||||
|
// CSRSparseMatrix is a singleton matrix (dense_shape is length 2).
|
||||||
|
int batch_size() const {
|
||||||
|
DCHECK(valid());
|
||||||
|
return batch_pointers_.NumElements() - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Decode(const VariantTensorData& p) {
|
||||||
|
if (p.tensors_.empty()) return false;
|
||||||
|
Metadata metadata;
|
||||||
|
if (!p.get_metadata(&metadata)) return false;
|
||||||
|
const bool validated = metadata.validated;
|
||||||
|
const DataType dtype = metadata.dtype;
|
||||||
|
|
||||||
|
// p.tensors_ should contain tensors {dense_shape, batch_pointers,
|
||||||
|
// row_pointers, col_indices, values}.
|
||||||
|
if (p.tensors_.size() != 5) return false;
|
||||||
|
|
||||||
|
Tensor dense_shape = p.tensors_[0];
|
||||||
|
if (dense_shape.dtype() != DT_INT64) return false;
|
||||||
|
if (dense_shape.dims() != 1) return false;
|
||||||
|
int rank = dense_shape.dim_size(0);
|
||||||
|
if (rank < 2 || rank > 3) return false;
|
||||||
|
|
||||||
|
Tensor batch_pointers(p.tensors_[1]);
|
||||||
|
Tensor row_pointers(p.tensors_[2]);
|
||||||
|
Tensor col_indices(p.tensors_[3]);
|
||||||
|
Tensor values(p.tensors_[4]);
|
||||||
|
|
||||||
|
// Check that the validated bool is consistent with the data.
|
||||||
|
Status s = ValidateTypesAndShapes(dtype, dense_shape, batch_pointers,
|
||||||
|
row_pointers, col_indices, values);
|
||||||
|
if (s.ok() != validated) return false;
|
||||||
|
|
||||||
|
// Save to this object.
|
||||||
|
metadata_ = metadata;
|
||||||
|
dense_shape_ = std::move(dense_shape);
|
||||||
|
batch_pointers_ = std::move(batch_pointers);
|
||||||
|
row_pointers_ = std::move(row_pointers);
|
||||||
|
col_indices_ = std::move(col_indices);
|
||||||
|
values_ = std::move(values);
|
||||||
|
SetupVecs();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Encode(VariantTensorData* p) const {
|
||||||
|
DCHECK(valid());
|
||||||
|
|
||||||
|
// Store metadata_ to p's metadata
|
||||||
|
p->set_metadata(metadata_);
|
||||||
|
|
||||||
|
// Store dense_shape, row_pointers, col_indices, and values to p->tensors_.
|
||||||
|
p->tensors_.reserve(5);
|
||||||
|
p->tensors_.push_back(dense_shape_);
|
||||||
|
p->tensors_.push_back(batch_pointers_);
|
||||||
|
p->tensors_.push_back(row_pointers_);
|
||||||
|
p->tensors_.push_back(col_indices_);
|
||||||
|
p->tensors_.push_back(values_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This static method copies CSRSparseMatrices in all directions:
|
||||||
|
// Host->Device, Device->Host, and Device->Device.
|
||||||
|
static Status DeviceCopy(
|
||||||
|
const CSRSparseMatrix& from, CSRSparseMatrix* to,
|
||||||
|
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
|
||||||
|
VLOG(2) << "DeviceCopy from type: " << DataTypeString(from.dtype())
|
||||||
|
<< " and shape: " << from.dense_shape().DebugString();
|
||||||
|
Tensor to_row_ptr(DT_INT32);
|
||||||
|
Tensor to_col_ind(DT_INT32);
|
||||||
|
Tensor to_values(from.dtype());
|
||||||
|
TF_RETURN_IF_ERROR(copy(from.row_pointers(), &to_row_ptr));
|
||||||
|
TF_RETURN_IF_ERROR(copy(from.col_indices(), &to_col_ind));
|
||||||
|
TF_RETURN_IF_ERROR(copy(from.values(), &to_values));
|
||||||
|
return CreateCSRSparseMatrix(from.dtype(),
|
||||||
|
from.dense_shape(), // Always on host.
|
||||||
|
from.batch_pointers(), // Always on host.
|
||||||
|
to_row_ptr, to_col_ind, to_values, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
CSRSparseMatrix(DataType dtype, const Tensor& dense_shape,
|
||||||
|
const Tensor& batch_pointers, const Tensor& row_pointers,
|
||||||
|
const Tensor& col_indices, const Tensor& values)
|
||||||
|
: metadata_{false, dtype},
|
||||||
|
dense_shape_(dense_shape),
|
||||||
|
batch_pointers_(batch_pointers),
|
||||||
|
row_pointers_(row_pointers),
|
||||||
|
col_indices_(col_indices),
|
||||||
|
values_(values) {}
|
||||||
|
|
||||||
|
void SetupVecs() {
|
||||||
|
if (!metadata_.validated) return;
|
||||||
|
batch_pointers_vec_.reset(
|
||||||
|
new TTypes<int32>::Vec(batch_pointers_.vec<int32>()));
|
||||||
|
row_pointers_vec_.reset(new TTypes<int32>::Vec(row_pointers_.vec<int32>()));
|
||||||
|
col_indices_vec_.reset(new TTypes<int32>::Vec(col_indices_.vec<int32>()));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ClearVecs() {
|
||||||
|
batch_pointers_vec_.reset();
|
||||||
|
row_pointers_vec_.reset();
|
||||||
|
col_indices_vec_.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Status ValidateTypesAndShapes(DataType dtype,
|
||||||
|
const Tensor& dense_shape,
|
||||||
|
const Tensor& batch_pointers,
|
||||||
|
const Tensor& row_pointers,
|
||||||
|
const Tensor& col_indices,
|
||||||
|
const Tensor& values) {
|
||||||
|
// TODO(ebrevdo): Consider adding support for other floating point types
|
||||||
|
// (namely, float16).
|
||||||
|
if (dtype != DT_FLOAT && dtype != DT_DOUBLE && dtype != DT_COMPLEX64 &&
|
||||||
|
dtype != DT_COMPLEX128) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: dtype = ", DataTypeString(dtype),
|
||||||
|
" not in {float32, float64, complex64, complex128}");
|
||||||
|
}
|
||||||
|
// dense_shape checks
|
||||||
|
if (dense_shape.dtype() != DT_INT64) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: dense_shape.dtype() = ",
|
||||||
|
DataTypeString(dense_shape.dtype()), " != int64");
|
||||||
|
}
|
||||||
|
if (dense_shape.dims() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: dense_shape should be a vector, but saw "
|
||||||
|
"tensor: ",
|
||||||
|
dense_shape.DebugString());
|
||||||
|
}
|
||||||
|
int rank = dense_shape.dim_size(0);
|
||||||
|
if (rank < 2 || rank > 3) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: dense_shape should be a 2- or 3- vector, "
|
||||||
|
"but saw: ",
|
||||||
|
dense_shape.SummarizeValue(5));
|
||||||
|
}
|
||||||
|
auto dense_shape_t = dense_shape.vec<int64>();
|
||||||
|
int batch_size = (rank == 2) ? 1 : dense_shape_t(0);
|
||||||
|
|
||||||
|
if (batch_pointers.dtype() != DT_INT32) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: batch_pointers.dtype() = ",
|
||||||
|
DataTypeString(batch_pointers.dtype()), " != int32");
|
||||||
|
}
|
||||||
|
if (batch_pointers.dims() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: batch_indices is not a vector, saw "
|
||||||
|
"shape: ",
|
||||||
|
batch_pointers.shape().DebugString());
|
||||||
|
}
|
||||||
|
|
||||||
|
// batch size checks
|
||||||
|
if (batch_size != batch_pointers.NumElements() - 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: dense_shape is ",
|
||||||
|
dense_shape.SummarizeValue(5),
|
||||||
|
" but batch pointers implies batch size is ",
|
||||||
|
batch_pointers.NumElements() - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (row_pointers.dtype() != DT_INT32) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: row_indices.dtype() = ",
|
||||||
|
DataTypeString(row_pointers.dtype()), " != int32");
|
||||||
|
}
|
||||||
|
if (row_pointers.dims() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: row_indices is not a vector, saw shape: ",
|
||||||
|
row_pointers.shape().DebugString());
|
||||||
|
}
|
||||||
|
if (col_indices.dtype() != DT_INT32) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: col_indices.dtype() = ",
|
||||||
|
DataTypeString(col_indices.dtype()), " != int32");
|
||||||
|
}
|
||||||
|
if (col_indices.dims() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: col_indices is not a vector, saw shape: ",
|
||||||
|
col_indices.shape().DebugString());
|
||||||
|
}
|
||||||
|
if (values.dtype() != dtype) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: values.dtype() = ",
|
||||||
|
DataTypeString(values.dtype()),
|
||||||
|
" != dtype = ", DataTypeString(dtype));
|
||||||
|
}
|
||||||
|
if (values.dims() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: values is not a vector, saw shape: ",
|
||||||
|
values.shape().DebugString());
|
||||||
|
}
|
||||||
|
if (col_indices.dim_size(0) != values.dim_size(0)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrix::Validate: size(col_indices) = ",
|
||||||
|
col_indices.dim_size(0), " != size(values) = ", values.dim_size(0));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Metadata {
|
||||||
|
bool validated;
|
||||||
|
DataType dtype;
|
||||||
|
};
|
||||||
|
Metadata metadata_;
|
||||||
|
Tensor dense_shape_;
|
||||||
|
Tensor batch_pointers_;
|
||||||
|
Tensor row_pointers_;
|
||||||
|
Tensor col_indices_;
|
||||||
|
Tensor values_;
|
||||||
|
std::unique_ptr<TTypes<int32>::Vec> batch_pointers_vec_;
|
||||||
|
std::unique_ptr<TTypes<int32>::Vec> row_pointers_vec_;
|
||||||
|
std::unique_ptr<TTypes<int32>::Vec> col_indices_vec_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Call BinaryFunctor<Device, T>()(ctx, a, b, c)
|
||||||
|
// where T depends on a.dtype(). T will be one of: float, double,
|
||||||
|
// complex64, complex128.
|
||||||
|
template <typename Device, template <typename, typename> class BinaryFunctor>
|
||||||
|
Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx,
|
||||||
|
const CSRSparseMatrix& a,
|
||||||
|
const CSRSparseMatrix& b,
|
||||||
|
CSRSparseMatrix* c) {
|
||||||
|
DataType dt = a.dtype();
|
||||||
|
if (dt != b.dtype()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrixBinaryHelper: Inconsistent dtypes for input matrices, "
|
||||||
|
"a "
|
||||||
|
"dtype: ",
|
||||||
|
DataTypeString(dt), ", b dtype: ", DataTypeString(b.dtype()));
|
||||||
|
}
|
||||||
|
switch (dt) {
|
||||||
|
case DT_FLOAT: {
|
||||||
|
BinaryFunctor<Device, float> functor(ctx);
|
||||||
|
return functor(a, b, c);
|
||||||
|
}
|
||||||
|
case DT_DOUBLE: {
|
||||||
|
BinaryFunctor<Device, double> functor(ctx);
|
||||||
|
return functor(a, b, c);
|
||||||
|
}
|
||||||
|
case DT_COMPLEX64: {
|
||||||
|
BinaryFunctor<Device, complex64> functor(ctx);
|
||||||
|
return functor(a, b, c);
|
||||||
|
}
|
||||||
|
case DT_COMPLEX128: {
|
||||||
|
BinaryFunctor<Device, complex128> functor(ctx);
|
||||||
|
return functor(a, b, c);
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"CSRSparseMatrixBinaryHelper: a.dtype (", DataTypeString(dt),
|
||||||
|
") is not one of: float, double, complex64, complex128");
|
||||||
|
}
|
||||||
|
}

// Call UnaryFunctor<Device, T>()(ctx, a, b)
// where T depends on a.dtype(). T will be one of: float, double,
// complex64, complex128.
template <typename Device, template <typename, typename> class UnaryFunctor>
Status CSRSparseMatrixUnaryHelper(OpKernelContext* ctx,
                                  const CSRSparseMatrix& a,
                                  CSRSparseMatrix* b) {
  DataType dt = a.dtype();
  switch (dt) {
    case DT_FLOAT: {
      UnaryFunctor<Device, float> functor(ctx);
      return functor(a, b);
    }
    case DT_DOUBLE: {
      UnaryFunctor<Device, double> functor(ctx);
      return functor(a, b);
    }
    case DT_COMPLEX64: {
      UnaryFunctor<Device, complex64> functor(ctx);
      return functor(a, b);
    }
    case DT_COMPLEX128: {
      UnaryFunctor<Device, complex128> functor(ctx);
      return functor(a, b);
    }
    default:
      return errors::InvalidArgument(
          "CSRSparseMatrixUnaryHelper: a.dtype (", DataTypeString(dt),
          ") is not one of: float, double, complex64, complex128");
  }
}
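
// Illustrative sketch (not part of this change): how an op kernel's Compute()
// might combine ExtractVariantFromInput (defined below) with
// CSRSparseMatrixUnaryHelper. `MyCSRUnaryFunctor` is a hypothetical functor
// name; the surrounding pattern mirrors the kernels added in this change.
//
//   void Compute(OpKernelContext* ctx) override {
//     const CSRSparseMatrix* input_matrix;
//     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
//     CSRSparseMatrix output_matrix;
//     OP_REQUIRES_OK(ctx,
//                    (CSRSparseMatrixUnaryHelper<CPUDevice, MyCSRUnaryFunctor>(
//                        ctx, *input_matrix, &output_matrix)));
//     Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
//     output_t.scalar<Variant>()() = std::move(output_matrix);
//     ctx->set_output(0, output_t);
//   }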

template <typename T>
struct ConstCSRComponent {
  TTypes<int32>::UnalignedConstVec row_ptr;
  TTypes<int32>::UnalignedConstVec col_ind;
  typename TTypes<T>::UnalignedConstVec values;
  TTypes<int64>::ConstVec dense_shape_host;
};

template <typename T>
struct CSRComponent {
  TTypes<int32>::UnalignedVec row_ptr;
  TTypes<int32>::UnalignedVec col_ind;
  typename TTypes<T>::UnalignedVec values;
  TTypes<int64>::Vec dense_shape_host;
};

template <typename T>
Status ExtractVariantFromInput(OpKernelContext* ctx, int index,
                               const T** value) {
  const Tensor& input_t = ctx->input(index);
  const Variant& input_variant = input_t.scalar<Variant>()();
  *value = input_variant.get<T>();
  if (*value == nullptr) {
    return errors::InvalidArgument("Could not retrieve Variant input ", index);
  }
  if (!(*value)->valid()) {
    return errors::InvalidArgument("Variant input ", index, " is not valid.");
  }
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
150 tensorflow/core/kernels/sparse/sparse_matrix_components_op.cc Normal file
@ -0,0 +1,150 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/slice_op.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class CSRSparseMatrixComponentsOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit CSRSparseMatrixComponentsOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* c) final {
|
||||||
|
const CSRSparseMatrix* csr_sparse_matrix;
|
||||||
|
OP_REQUIRES_OK(c, ExtractVariantFromInput(c, 0, &csr_sparse_matrix));
|
||||||
|
|
||||||
|
const Tensor& index_t = c->input(1);
|
||||||
|
OP_REQUIRES(c, DataTypeToEnum<T>::value == csr_sparse_matrix->dtype(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"dtype of input is not equal to 'type': ",
|
||||||
|
DataTypeString(csr_sparse_matrix->dtype()), " vs. ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value)));
|
||||||
|
OP_REQUIRES(c, index_t.dims() == 0,
|
||||||
|
errors::InvalidArgument("index should be a scalar, but saw: ",
|
||||||
|
index_t.DebugString()));
|
||||||
|
int32 index = index_t.scalar<int32>()();
|
||||||
|
OP_REQUIRES(c, index >= 0 && index < csr_sparse_matrix->batch_size(),
|
||||||
|
errors::InvalidArgument("index (", index, ") not in [0, ",
|
||||||
|
csr_sparse_matrix->batch_size(), ")"));
|
||||||
|
|
||||||
|
if (csr_sparse_matrix->dims() == 2) {
|
||||||
|
c->set_output(0, csr_sparse_matrix->row_pointers());
|
||||||
|
c->set_output(1, csr_sparse_matrix->col_indices());
|
||||||
|
c->set_output(2, csr_sparse_matrix->values());
|
||||||
|
} else {
|
||||||
|
auto batch_ptrs = csr_sparse_matrix->batch_pointers().vec<int32>();
|
||||||
|
auto dense_shape = csr_sparse_matrix->dense_shape().vec<int64>();
|
||||||
|
int64 rows = dense_shape(1);
|
||||||
|
int nnz = batch_ptrs(index + 1) - batch_ptrs(index);
|
||||||
|
Tensor* row_ptrs_t;
|
||||||
|
Tensor* col_inds_t;
|
||||||
|
Tensor* values_t;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
c, c->allocate_output(0, TensorShape({rows + 1}), &row_ptrs_t));
|
||||||
|
OP_REQUIRES_OK(c, c->allocate_output(1, TensorShape({nnz}), &col_inds_t));
|
||||||
|
OP_REQUIRES_OK(c, c->allocate_output(2, TensorShape({nnz}), &values_t));
|
||||||
|
auto row_ptrs = row_ptrs_t->vec<int32>();
|
||||||
|
auto col_inds = col_inds_t->vec<int32>();
|
||||||
|
auto values = values_t->vec<T>();
|
||||||
|
|
||||||
|
functor::Slice<Device, int32, 1> slice_int;
|
||||||
|
functor::Slice<Device, T, 1> slice_t;
|
||||||
|
typedef Eigen::DSizes<Eigen::DenseIndex, 1> EVec;
|
||||||
|
const Device& d = c->eigen_device<Device>();
|
||||||
|
slice_int(d,
|
||||||
|
/*output*/ row_ptrs,
|
||||||
|
/*input*/ csr_sparse_matrix->row_pointers().vec<int32>(),
|
||||||
|
/*slice_indices*/ EVec{index * (rows + 1)},
|
||||||
|
/*slice_sizes*/ EVec{rows + 1});
|
||||||
|
slice_int(d,
|
||||||
|
/*output*/ col_inds,
|
||||||
|
/*input*/ csr_sparse_matrix->col_indices().vec<int32>(),
|
||||||
|
/*slice_indices*/ EVec{batch_ptrs(index)},
|
||||||
|
/*slice_sizes*/ EVec{nnz});
|
||||||
|
slice_t(d,
|
||||||
|
/*output*/ values, /*input*/ csr_sparse_matrix->values().vec<T>(),
|
||||||
|
/*slice_indices*/ EVec{batch_ptrs(index)},
|
||||||
|
/*slice_sizes*/ EVec{nnz});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
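// Worked example of the batched slicing above (illustrative only, assuming a
// hypothetical batch of 2 matrices with dense_shape [2, 3, 4]):
//   rows = 3, so each batch owns rows + 1 = 4 row pointers, stored
//   back-to-back in row_pointers(); batch `index` starts at offset
//   index * (rows + 1).
//   If batch_ptrs = [0, 5, 7], then batch 0 holds nnz = 5 - 0 = 5 entries and
//   batch 1 holds nnz = 7 - 5 = 2 entries; its col_inds/values slice starts at
//   offset batch_ptrs(1) = 5 with length 2.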
|
||||||
|
|
||||||
|
#define REGISTER(DEV, T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("CSRSparseMatrixComponents") \
|
||||||
|
.Device(DEVICE_##DEV) \
|
||||||
|
.TypeConstraint<T>("type") \
|
||||||
|
.HostMemory("index"), \
|
||||||
|
CSRSparseMatrixComponentsOp<DEV##Device, T>);
|
||||||
|
|
||||||
|
REGISTER(CPU, float)
|
||||||
|
REGISTER(CPU, double)
|
||||||
|
REGISTER(CPU, complex64)
|
||||||
|
REGISTER(CPU, complex128)
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
REGISTER(GPU, float)
|
||||||
|
REGISTER(GPU, double)
|
||||||
|
REGISTER(GPU, complex64)
|
||||||
|
REGISTER(GPU, complex128)
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
// TODO(ebrevdo): This should move to a slice_functor.cc
|
||||||
|
#define DECLARE_GPU_SPEC(T) \
|
||||||
|
template <> \
|
||||||
|
void Slice<GPUDevice, T, 1>::operator()( \
|
||||||
|
const GPUDevice& d, typename TTypes<T, 1>::Tensor output, \
|
||||||
|
typename TTypes<T, 1>::ConstTensor input, \
|
||||||
|
const Eigen::DSizes<Eigen::DenseIndex, 1>& indices, \
|
||||||
|
const Eigen::DSizes<Eigen::DenseIndex, 1>& sizes); \
|
||||||
|
extern template struct Slice<GPUDevice, T, 1>;
|
||||||
|
|
||||||
|
DECLARE_GPU_SPEC(int32);
|
||||||
|
DECLARE_GPU_SPEC(float);
|
||||||
|
DECLARE_GPU_SPEC(double);
|
||||||
|
DECLARE_GPU_SPEC(complex64);
|
||||||
|
DECLARE_GPU_SPEC(complex128);
|
||||||
|
|
||||||
|
#undef DECLARE_GPU_SPEC
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
131 tensorflow/core/kernels/sparse/sparse_ordering_amd_op.cc Normal file
@ -0,0 +1,131 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
#include "third_party/eigen3/Eigen/SparseCholesky"
|
||||||
|
#include "third_party/eigen3/Eigen/SparseCore"
|
||||||
|
#include "third_party/eigen3/Eigen/OrderingMethods"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Op to compute the Approximate Minimum Degree (AMD) ordering for a sparse
|
||||||
|
// matrix.
|
||||||
|
//
|
||||||
|
// Accepts a CSRSparseMatrix which may represent a single sparse matrix (rank 2)
|
||||||
|
// or a batch of sparse matrices (rank 3). Each component must be a square
|
||||||
|
// matrix. The input is assumed to be symmetric; only the lower triangular part
|
||||||
|
// of each component matrix is read. The numeric values of the sparse matrix
|
||||||
|
// do not affect the returned AMD ordering; only the sparsity pattern does.
|
||||||
|
//
|
||||||
|
// For each component sparse matrix A, the corresponding output Tensor
|
||||||
|
// represents the AMD ordering of A's rows and columns. The ordering is returned
|
||||||
|
// as a 1D Tensor (per batch) containing the list of indices, i.e. it contains
|
||||||
|
// each of the integers {0, ..., N-1} exactly once, where N is the number of rows
|
||||||
|
// of the sparse matrix. The ith element represents the index of the row that
|
||||||
|
// the ith row should map to.
|
||||||
|
|
||||||
|
// If P represents the permutation matrix corresponding to the indices, then the
|
||||||
|
// matrix:
|
||||||
|
// P^{-1} * A * P
|
||||||
|
// would have a sparse Cholesky decomposition with fewer structural non-zero
|
||||||
|
// elements than the sparse Cholesky decomposition of A itself.
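//
// Illustrative sketch (not part of this change): computing an AMD ordering for
// a single small matrix with the same Eigen API the kernel below relies on.
// The names `pattern` and `perm` are local to this example.
//
//   // 3x3 symmetric sparsity pattern; the stored values are irrelevant.
//   Eigen::SparseMatrix<int, Eigen::RowMajor> pattern(3, 3);
//   pattern.insert(0, 0) = 1;
//   pattern.insert(1, 0) = 1;
//   pattern.insert(1, 1) = 1;
//   pattern.insert(2, 2) = 1;
//   pattern.makeCompressed();
//   Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic, int> perm;
//   Eigen::AMDOrdering<int> amd;
//   amd(pattern.selfadjointView<Eigen::Lower>(), perm);
//   // perm.indices() now holds a permutation of {0, 1, 2}; applying it as
//   // P^{-1} * A * P tends to reduce Cholesky fill-in.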
|
||||||
|
class CSROrderingAMDCPUOp : public OpKernel {
|
||||||
|
using SparseMatrix = Eigen::SparseMatrix<int, Eigen::RowMajor>;
|
||||||
|
using Indices =
|
||||||
|
Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
using IndicesMap = Eigen::Map<Indices>;
|
||||||
|
using ConstIndicesMap = Eigen::Map<const Indices>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit CSROrderingAMDCPUOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) final {
|
||||||
|
// Extract the input CSRSparseMatrix.
|
||||||
|
const CSRSparseMatrix* input_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
|
||||||
|
|
||||||
|
const Tensor& dense_shape = input_matrix->dense_shape();
|
||||||
|
const int rank = dense_shape.dim_size(0);
|
||||||
|
OP_REQUIRES(ctx, rank == 2 || rank == 3,
|
||||||
|
errors::InvalidArgument("sparse matrix must have rank 2 or 3; ",
|
||||||
|
"but dense_shape has size ", rank));
|
||||||
|
|
||||||
|
auto dense_shape_vec = dense_shape.vec<int64>();
|
||||||
|
const int64 num_rows = dense_shape_vec((rank == 2) ? 0 : 1);
|
||||||
|
const int64 num_cols = dense_shape_vec((rank == 2) ? 1 : 2);
|
||||||
|
|
||||||
|
OP_REQUIRES(ctx, num_rows == num_cols,
|
||||||
|
errors::InvalidArgument("sparse matrix must be square; got: ",
|
||||||
|
num_rows, " != ", num_cols));
|
||||||
|
|
||||||
|
// Allocate the output permutation indices.
|
||||||
|
const int batch_size = input_matrix->batch_size();
|
||||||
|
TensorShape permutation_indices_shape =
|
||||||
|
(rank == 2) ? TensorShape{num_rows} : TensorShape{batch_size, num_rows};
|
||||||
|
Tensor permutation_indices(cpu_allocator(), DT_INT32,
|
||||||
|
permutation_indices_shape);
|
||||||
|
ctx->set_output(0, permutation_indices);
|
||||||
|
|
||||||
|
// Parallelize AMD computation across batches using a threadpool.
|
||||||
|
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
|
||||||
|
const int64 amd_cost_per_batch =
|
||||||
|
10 * num_rows * (input_matrix->total_nnz() / batch_size);
|
||||||
|
Shard(
|
||||||
|
worker_threads.num_threads, worker_threads.workers, batch_size,
|
||||||
|
amd_cost_per_batch, [&](int64 batch_begin, int64 batch_end) {
|
||||||
|
for (int64 batch_index = batch_begin; batch_index < batch_end;
|
||||||
|
++batch_index) {
|
||||||
|
// Define an Eigen SparseMatrix Map to operate on the
|
||||||
|
// CSRSparseMatrix component without copying the data.
|
||||||
|
// The values don't matter for computing the ordering, hence we
|
||||||
|
// reuse the column pointers as dummy values.
|
||||||
|
Eigen::Map<const SparseMatrix> sparse_matrix(
|
||||||
|
num_rows, num_rows, input_matrix->nnz(batch_index),
|
||||||
|
input_matrix->row_pointers_vec(batch_index).data(),
|
||||||
|
input_matrix->col_indices_vec(batch_index).data(),
|
||||||
|
input_matrix->col_indices_vec(batch_index).data());
|
||||||
|
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic, int>
|
||||||
|
permutation_matrix;
|
||||||
|
// Compute the AMD ordering.
|
||||||
|
Eigen::AMDOrdering<int> amd_ordering;
|
||||||
|
amd_ordering(sparse_matrix.template selfadjointView<Eigen::Lower>(),
|
||||||
|
permutation_matrix);
|
||||||
|
// Define an Eigen Map over the allocated output Tensor so that it
|
||||||
|
// can be mutated in place.
|
||||||
|
IndicesMap permutation_map(
|
||||||
|
permutation_indices.flat<int>().data() + batch_index * num_rows,
|
||||||
|
num_rows, 1);
|
||||||
|
permutation_map = permutation_matrix.indices();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseMatrixOrderingAMD").Device(DEVICE_CPU),
|
||||||
|
CSROrderingAMDCPUOp);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -0,0 +1,345 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||||
|
#include "tensorflow/core/platform/cuda.h"
|
||||||
|
|
||||||
|
using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
// Op to convert SparseTensors to CSR SparseMatrices on the CPU.
|
||||||
|
// Takes a SparseTensor of rank 2 or (if batched) 3 as the input. The
|
||||||
|
// SparseTensor's indices must be present in the canonical, row-major ordering.
|
||||||
|
//
|
||||||
|
// Returns a (batched) CSR SparseMatrix with the same dense shape and non-zero
|
||||||
|
// values.
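//
// Worked example of the COO -> CSR conversion performed below (illustrative
// only). For a single 3x4 matrix with canonical, row-major SparseTensor
// indices [[0, 1], [0, 3], [2, 2]]:
//   - csr_col_ind is just the column coordinates: [1, 3, 2];
//   - csr_row_ptr has length rows + 1 = 4 and holds, for each row, the offset
//     of its first nonzero: [0, 2, 2, 3] (row 0 owns entries [0, 2), row 1 is
//     empty, row 2 owns entry [2, 3));
//   - batch_ptr is [0, 3] since there is one batch holding all 3 nonzeros.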
|
||||||
|
template <typename T>
|
||||||
|
class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit SparseTensorToCSRSparseMatrixCPUOp(OpKernelConstruction* c)
|
||||||
|
: OpKernel(c) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) final {
|
||||||
|
const Tensor& indices = ctx->input(0);
|
||||||
|
const Tensor& values = ctx->input(1);
|
||||||
|
const Tensor& dense_shape = ctx->input(2);
|
||||||
|
const int rank = dense_shape.NumElements();
|
||||||
|
OP_REQUIRES(ctx, rank == 2 || rank == 3,
|
||||||
|
errors::InvalidArgument("SparseTensor must have rank 2 or 3; ",
|
||||||
|
"but indices has rank: ", rank));
|
||||||
|
auto dense_shape_vec = dense_shape.vec<int64>();
|
||||||
|
const int64 batch_size = (rank == 2) ? 1 : dense_shape_vec(0);
|
||||||
|
const int64 num_rows = dense_shape_vec((rank == 2) ? 0 : 1);
|
||||||
|
const int64 total_nnz = values.NumElements();
|
||||||
|
|
||||||
|
// Allocate output Tensors.
|
||||||
|
Tensor batch_ptr(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1}));
|
||||||
|
Tensor csr_col_ind(cpu_allocator(), DT_INT32, TensorShape({total_nnz}));
|
||||||
|
Tensor csr_row_ptr(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({(num_rows + 1) * batch_size}));
|
||||||
|
|
||||||
|
// Fill the row pointers with zeros.
|
||||||
|
functor::SetZeroFunctor<CPUDevice, int32> set_zero;
|
||||||
|
set_zero(ctx->eigen_device<CPUDevice>(), csr_row_ptr.flat<int32>());
|
||||||
|
|
||||||
|
// Convert from COO to CSR format.
|
||||||
|
functor::SparseTensorToCSRSparseMatrixCPUFunctor coo_to_csr;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, coo_to_csr(batch_size, num_rows, indices.template matrix<int64>(),
|
||||||
|
batch_ptr.vec<int32>(), csr_row_ptr.vec<int32>(),
|
||||||
|
csr_col_ind.vec<int32>()));
|
||||||
|
|
||||||
|
// Create the CSRSparseMatrix object from its component Tensors and prepare
|
||||||
|
// the Variant output Tensor.
|
||||||
|
CSRSparseMatrix output_csr_matrix;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
DataTypeToEnum<T>::value, dense_shape, batch_ptr, csr_row_ptr,
|
||||||
|
csr_col_ind, values, &output_csr_matrix));
|
||||||
|
Tensor* output_csr_matrix_tensor;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, ctx->allocate_output(0, TensorShape({}), &output_csr_matrix_tensor,
|
||||||
|
cpu_alloc));
|
||||||
|
output_csr_matrix_tensor->scalar<Variant>()() =
|
||||||
|
std::move(output_csr_matrix);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||||
|
public:
|
||||||
|
explicit SparseTensorToCSRSparseMatrixGPUOp(OpKernelConstruction* c)
|
||||||
|
: AsyncOpKernel(c) {}
|
||||||
|
|
||||||
|
void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
|
||||||
|
auto stream = c->op_device_context()->stream();
|
||||||
|
const Device& d = c->eigen_device<Device>();
|
||||||
|
|
||||||
|
const Tensor& indices_t = c->input(0);
|
||||||
|
const Tensor& values_t = c->input(1);
|
||||||
|
const Tensor& dense_shape_t = c->input(2);
|
||||||
|
const int rank = dense_shape_t.NumElements();
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c, rank == 2 || rank == 3,
|
||||||
|
errors::InvalidArgument("sparse tensor must have rank == 2 or 3; ",
|
||||||
|
"but indices has ", rank, " columns"),
|
||||||
|
done);
|
||||||
|
auto dense_shape = dense_shape_t.vec<int64>();
|
||||||
|
const int64 batch_size = (rank == 2) ? 1 : dense_shape(0);
|
||||||
|
const int64 rows = dense_shape((rank == 2) ? 0 : 1);
|
||||||
|
const int64 cols = dense_shape((rank == 2) ? 1 : 2);
|
||||||
|
|
||||||
|
ScratchSpace<int32> nnz_per_batch_host(c, batch_size, /*on_host*/ true);
|
||||||
|
|
||||||
|
Tensor nnz_per_batch_device_t;
|
||||||
|
if (rank == 2) {
|
||||||
|
// Simple case.
|
||||||
|
nnz_per_batch_host.mutable_data()[0] = indices_t.dim_size(0);
|
||||||
|
} else {
|
||||||
|
OP_REQUIRES_OK_ASYNC(c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({batch_size}),
|
||||||
|
&nnz_per_batch_device_t),
|
||||||
|
done);
|
||||||
|
auto nnz_per_batch_device = nnz_per_batch_device_t.vec<int32>();
|
||||||
|
|
||||||
|
functor::CalculateNNZPerBatchMatrixFromIndices<Device>
|
||||||
|
calculate_nnz_from_indices;
|
||||||
|
auto indices = indices_t.matrix<int64>();
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c, calculate_nnz_from_indices(c, indices, nnz_per_batch_device),
|
||||||
|
done);
|
||||||
|
|
||||||
|
perftools::gputools::DeviceMemoryBase nnz_per_batch_device_ptr(
|
||||||
|
static_cast<void*>(nnz_per_batch_device.data()));
|
||||||
|
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c,
|
||||||
|
stream
|
||||||
|
->ThenMemcpy(nnz_per_batch_host.mutable_data() /*host_dst*/,
|
||||||
|
nnz_per_batch_device_ptr /*gpu_src*/,
|
||||||
|
batch_size * sizeof(int32) /*size*/)
|
||||||
|
.ok(),
|
||||||
|
errors::Internal("SparseTensorToSparseMatrixGPUOp: failed to copy "
|
||||||
|
"nnz_per_batch from device"),
|
||||||
|
done);
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorReference nnz_per_batch_device_ref(nnz_per_batch_device_t);
|
||||||
|
auto convert_to_csr = [this, c, batch_size, nnz_per_batch_host,
|
||||||
|
nnz_per_batch_device_ref, stream, &d, &values_t,
|
||||||
|
&indices_t, &dense_shape_t, dense_shape, rows, cols,
|
||||||
|
rank, done]() {
|
||||||
|
// The data has been copied out of the nnz_per_batch_device
|
||||||
|
// tensor by the time we get here; we can unreference it.
|
||||||
|
nnz_per_batch_device_ref.Unref();
|
||||||
|
|
||||||
|
auto nnz_per_batch = nnz_per_batch_host.tensor().vec<int32>();
|
||||||
|
|
||||||
|
// Ensure that within the callback, the proper GPU settings are
|
||||||
|
// configured.
|
||||||
|
ScopedActivateExecutorContext scoped_activation{stream->parent()};
|
||||||
|
Tensor batch_ptr_t(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({batch_size + 1}));
|
||||||
|
|
||||||
|
auto batch_ptr = batch_ptr_t.vec<int32>();
|
||||||
|
auto indices = indices_t.matrix<int64>();
|
||||||
|
|
||||||
|
batch_ptr(0) = 0;
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
batch_ptr(i + 1) = batch_ptr(i) + nnz_per_batch(i);
|
||||||
|
}
|
||||||
|
int total_nnz = batch_ptr(batch_size);
|
||||||
|
OP_REQUIRES_ASYNC(
|
||||||
|
c, total_nnz == values_t.NumElements(),
|
||||||
|
errors::Internal("nnz returned by "
|
||||||
|
"CalculateNNZPerBatchMatrixFromInd"
|
||||||
|
"ices != len(values): ",
|
||||||
|
total_nnz, " vs. ", values_t.NumElements()),
|
||||||
|
done);
|
||||||
|
|
||||||
|
Tensor coo_col_ind_t;
|
||||||
|
Tensor csr_row_ptr_t;
|
||||||
|
Tensor csr_values_t = values_t;
|
||||||
|
|
||||||
|
Tensor coo_row_ind_t;
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({total_nnz}), &coo_row_ind_t),
|
||||||
|
done);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({total_nnz}), &coo_col_ind_t),
|
||||||
|
done);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
c->allocate_temp(DT_INT32, TensorShape({batch_size * (rows + 1)}),
|
||||||
|
&csr_row_ptr_t),
|
||||||
|
done);
|
||||||
|
|
||||||
|
auto coo_row_ind = coo_row_ind_t.vec<int32>();
|
||||||
|
auto coo_col_ind = coo_col_ind_t.vec<int32>();
|
||||||
|
auto csr_row_ptr = csr_row_ptr_t.vec<int32>();
|
||||||
|
|
||||||
|
// Convert SparseTensor rep to coo row ind, coo col ind.
|
||||||
|
if (total_nnz > 0) {
|
||||||
|
functor::SparseTensorToCOOSparseMatrix<Device> st_to_coo;
|
||||||
|
st_to_coo(d, dense_shape, indices, coo_row_ind, coo_col_ind);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set all csr row pointers to zero, so that when iterating over
|
||||||
|
// batches converting coo to csr, we do not have to perform an
|
||||||
|
// unaligned SetZero for any nnz == 0 minibatches. coo2csr has
|
||||||
|
// a bug if you have empty coo rows.
|
||||||
|
// TODO(ebrevdo): File bug w/ nvidia so coo2csr can handle
|
||||||
|
// zero-element input coo rows.
|
||||||
|
functor::SetZeroFunctor<Device, int32> set_zero;
|
||||||
|
set_zero(d, csr_row_ptr_t.flat<int32>());
|
||||||
|
|
||||||
|
functor::COOSparseMatrixToCSRSparseMatrix<Device> coo_to_csr;
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
int nnz_i = batch_ptr(i + 1) - batch_ptr(i);
|
||||||
|
if (nnz_i == 0) {
|
||||||
|
// This is an empty minibatch; no call to coo2csr: it's
|
||||||
|
// handled by the SetZero above.
|
||||||
|
} else {
|
||||||
|
// Convert coo to csr.
|
||||||
|
auto coo_row_ind_i =
|
||||||
|
TTypes<int32>::UnalignedVec(&coo_row_ind(batch_ptr(i)), nnz_i);
|
||||||
|
auto csr_row_ptr_i = TTypes<int32>::UnalignedVec(
|
||||||
|
&csr_row_ptr((rows + 1) * i), rows + 1);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c, coo_to_csr(c, rows, cols, coo_row_ind_i, csr_row_ptr_i), done);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CSRSparseMatrix matrix;
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c,
|
||||||
|
CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
values_t.dtype(), dense_shape_t, batch_ptr_t, csr_row_ptr_t,
|
||||||
|
coo_col_ind_t, csr_values_t, &matrix),
|
||||||
|
done);
|
||||||
|
Tensor* matrix_t;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
c, c->allocate_output(0, TensorShape({}), &matrix_t, cpu_alloc),
|
||||||
|
done);
|
||||||
|
matrix_t->scalar<Variant>()() = std::move(matrix);
|
||||||
|
|
||||||
|
done();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (rank == 2) {
|
||||||
|
convert_to_csr();
|
||||||
|
} else {
|
||||||
|
// Launch the GPU kernel to count nnz entries, then call convert_to_csr.
|
||||||
|
c->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
|
||||||
|
stream, convert_to_csr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Status CalculateNNZPerBatchMatrixFromIndices<GPUDevice>::operator()(
|
||||||
|
OpKernelContext* c, TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int32>::Vec nnz_per_batch);
|
||||||
|
extern template struct CalculateNNZPerBatchMatrixFromIndices<GPUDevice>;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct SparseTensorToCOOSparseMatrix<GPUDevice> {
|
||||||
|
void operator()(const GPUDevice& d, TTypes<int64>::ConstVec host_dense_shape,
|
||||||
|
TTypes<int64>::ConstMatrix indices,
|
||||||
|
TTypes<int>::Vec coo_row_ind, TTypes<int>::Vec coo_col_ind);
|
||||||
|
};
|
||||||
|
extern template struct SparseTensorToCOOSparseMatrix<GPUDevice>;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct COOSparseMatrixToCSRSparseMatrix<GPUDevice> {
|
||||||
|
Status operator()(OpKernelContext* c, const int rows, const int cols,
|
||||||
|
TTypes<int>::UnalignedVec coo_row_ind,
|
||||||
|
TTypes<int>::UnalignedVec csr_row_ptr) {
|
||||||
|
CudaSparse cuda_sparse(c);
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||||
|
return cuda_sparse.Coo2csr(coo_row_ind.data(),
|
||||||
|
/*nnz*/ coo_row_ind.size(),
|
||||||
|
/*m == rows of A*/ rows, csr_row_ptr.data());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
#define REGISTER_GPU(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \
|
||||||
|
.Device(DEVICE_GPU) \
|
||||||
|
.TypeConstraint<T>("T") \
|
||||||
|
.HostMemory("dense_shape"), \
|
||||||
|
SparseTensorToCSRSparseMatrixGPUOp<GPUDevice, T>);
|
||||||
|
|
||||||
|
REGISTER_GPU(float)
|
||||||
|
REGISTER_GPU(double)
|
||||||
|
REGISTER_GPU(complex64)
|
||||||
|
REGISTER_GPU(complex128)
|
||||||
|
|
||||||
|
#undef REGISTER_GPU
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#define REGISTER_CPU(T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseTensorToCSRSparseMatrix") \
|
||||||
|
.Device(DEVICE_CPU) \
|
||||||
|
.TypeConstraint<T>("T"), \
|
||||||
|
SparseTensorToCSRSparseMatrixCPUOp<T>);
|
||||||
|
|
||||||
|
REGISTER_CPU(float)
|
||||||
|
REGISTER_CPU(double)
|
||||||
|
REGISTER_CPU(complex64)
|
||||||
|
REGISTER_CPU(complex128)
|
||||||
|
|
||||||
|
#undef REGISTER_CPU
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
189 tensorflow/core/kernels/sparse/transpose_op.cc Normal file
@ -0,0 +1,189 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Implements the kernel for the CSRTranspose op, which transposes the
|
||||||
|
// two innermost dimensions of a CSRSparseMatrix object stored in a
|
||||||
|
// DT_VARIANT.
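//
// Worked example of the transpose performed below (illustrative only). For a
// single 2x3 CSR matrix with
//   row_ptr = [0, 2, 3], col_ind = [0, 2, 1], values = [a, b, c]
// (i.e. A = [[a, 0, b], [0, c, 0]]), the transposed 3x2 matrix is
//   row_ptr = [0, 1, 2, 3], col_ind = [0, 1, 0], values = [a, c, b],
// which is exactly the CSC representation of the input; on GPU this is what
// the csr2csc conversion invoked below produces.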
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/sparse/transpose_op.h"
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/slice_op.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
class CSRTransposeOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit CSRTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("conjugate", &conjugate_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
const CSRSparseMatrix* input_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &input_matrix));
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, input_matrix->dtype() == DataTypeToEnum<T>::value,
|
||||||
|
errors::InvalidArgument("dtype of input is not equal to 'type': ",
|
||||||
|
DataTypeString(input_matrix->dtype()), " vs. ",
|
||||||
|
DataTypeString(DataTypeToEnum<T>::value)));
|
||||||
|
|
||||||
|
// Allocate output shapes
|
||||||
|
functor::CSRSparseMatrixTranspose<Device, T> transpose;
|
||||||
|
CSRSparseMatrix output_matrix;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
transpose(ctx, conjugate_, *input_matrix, &output_matrix));
|
||||||
|
Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
|
||||||
|
output_t.scalar<Variant>()() = std::move(output_matrix);
|
||||||
|
ctx->set_output(0, output_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool conjugate_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef GOOGLE_CUDA
|
||||||
|
#define REGISTER(DEV, T) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseMatrixTranspose") \
|
||||||
|
.Device(DEVICE_##DEV) \
|
||||||
|
.TypeConstraint<T>("type"), \
|
||||||
|
CSRTransposeOp<DEV##Device, T>);
|
||||||
|
|
||||||
|
REGISTER(GPU, float)
|
||||||
|
REGISTER(GPU, double)
|
||||||
|
REGISTER(GPU, complex64)
|
||||||
|
REGISTER(GPU, complex128)
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
Status CSRSparseMatrixTranspose<Device, T>::operator()(
|
||||||
|
OpKernelContext* ctx, bool conjugate, const CSRSparseMatrix& input_matrix,
|
||||||
|
CSRSparseMatrix* output_matrix) {
|
||||||
|
const int rank = input_matrix.dims();
|
||||||
|
Tensor output_dense_shape_t(cpu_allocator(), DT_INT64, TensorShape({rank}));
|
||||||
|
const Tensor& input_dense_shape_t = input_matrix.dense_shape();
|
||||||
|
auto input_dense_shape = input_dense_shape_t.vec<int64>();
|
||||||
|
auto output_dense_shape = output_dense_shape_t.vec<int64>();
|
||||||
|
const int64 batch_size = input_matrix.batch_size();
|
||||||
|
if (rank == 3) {
|
||||||
|
output_dense_shape(0) = batch_size;
|
||||||
|
}
|
||||||
|
output_dense_shape(rank - 2) = input_dense_shape(rank - 1);
|
||||||
|
output_dense_shape(rank - 1) = input_dense_shape(rank - 2);
|
||||||
|
const int64 output_rows = output_dense_shape(rank - 2);
|
||||||
|
|
||||||
|
// nnzs per batch do not change with matrix transposition.
|
||||||
|
Tensor batch_ptr_t = input_matrix.batch_pointers();
|
||||||
|
const int total_nnz = input_matrix.total_nnz();
|
||||||
|
|
||||||
|
Tensor output_row_ptr_t;
|
||||||
|
Tensor output_col_ind_t;
|
||||||
|
Tensor output_values_t;
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||||
|
DT_INT32, TensorShape({batch_size * (output_rows + 1)}),
|
||||||
|
&output_row_ptr_t));
|
||||||
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(DT_INT32, TensorShape({total_nnz}),
|
||||||
|
&output_col_ind_t));
|
||||||
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(
|
||||||
|
DataTypeToEnum<T>::value, TensorShape({total_nnz}), &output_values_t));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
DataTypeToEnum<T>::value, output_dense_shape_t, batch_ptr_t,
|
||||||
|
output_row_ptr_t, output_col_ind_t, output_values_t, output_matrix));
|
||||||
|
|
||||||
|
// Set the output row pointers to zero, in case we hit any empty
|
||||||
|
// input batches.
|
||||||
|
functor::SetZeroFunctor<Device, int32> set_zero;
|
||||||
|
const Device& d = ctx->eigen_device<Device>();
|
||||||
|
set_zero(d, output_row_ptr_t.flat<int32>());
|
||||||
|
|
||||||
|
functor::CSRSparseMatrixTransposeComponent<Device, T> transpose_component;
|
||||||
|
for (int i = 0; i < batch_size; ++i) {
|
||||||
|
if (output_matrix->nnz(i) == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
ConstCSRComponent<T> input_comp{
|
||||||
|
input_matrix.row_pointers_vec(i), input_matrix.col_indices_vec(i),
|
||||||
|
input_matrix.values_vec<T>(i), input_dense_shape};
|
||||||
|
CSRComponent<T> output_comp{
|
||||||
|
output_matrix->row_pointers_vec(i), output_matrix->col_indices_vec(i),
|
||||||
|
output_matrix->values_vec<T>(i), output_dense_shape};
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(transpose_component(ctx, input_comp, &output_comp));
|
||||||
|
}
|
||||||
|
if (conjugate) {
|
||||||
|
// conjugate all values with a single kernel launch.
|
||||||
|
maybe_conj_inplace<Device, T>::run(d, &output_values_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef GOOGLE_CUDA
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct CSRSparseMatrixTransposeComponent<GPUDevice, T> {
|
||||||
|
Status operator()(OpKernelContext* ctx, const ConstCSRComponent<T>& x,
|
||||||
|
CSRComponent<T>* y) {
|
||||||
|
CudaSparse cuda_sparse(ctx);
|
||||||
|
TF_RETURN_IF_ERROR(cuda_sparse.Initialize());
|
||||||
|
const cusparseAction_t copyValues = CUSPARSE_ACTION_NUMERIC;
|
||||||
|
const int rank = x.dense_shape_host.size();
|
||||||
|
const int m = x.row_ptr.size() - 1;
|
||||||
|
const int n = x.dense_shape_host(rank - 1);
|
||||||
|
const int nnz = x.col_ind.size();
|
||||||
|
DCHECK_EQ(nnz, x.values.size());
|
||||||
|
DCHECK_EQ(n, y->row_ptr.size() - 1);
|
||||||
|
DCHECK_EQ(rank, y->dense_shape_host.size());
|
||||||
|
DCHECK_EQ(m, y->dense_shape_host(rank - 1));
|
||||||
|
DCHECK_EQ(nnz, y->col_ind.size());
|
||||||
|
DCHECK_EQ(nnz, y->values.size());
|
||||||
|
|
||||||
|
return cuda_sparse.Csr2csc(
|
||||||
|
m, n, nnz, x.values.data() /*csrVal*/, x.row_ptr.data() /*csrRowPtr*/,
|
||||||
|
x.col_ind.data() /*csrColInd*/, y->values.data() /*cscVal*/,
|
||||||
|
y->col_ind.data() /*cscRowInd*/, y->row_ptr.data() /*cscColPtr*/,
|
||||||
|
copyValues);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
} // namespace functor
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
73 tensorflow/core/kernels/sparse/transpose_op.h Normal file
@ -0,0 +1,73 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TRANSPOSE_OP_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_SPARSE_TRANSPOSE_OP_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct maybe_conj_inplace {
|
||||||
|
static void run(const Device& d, Tensor* t) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct maybe_conj_inplace<Device, complex64> {
|
||||||
|
static void run(const Device& d, Tensor* t) {
|
||||||
|
functor::UnaryFunctor<Device, functor::conj<complex64>> conj;
|
||||||
|
conj(d, t->flat<complex64>() /*out*/,
|
||||||
|
const_cast<const Tensor*>(t)->flat<complex64>() /*in*/);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct maybe_conj_inplace<Device, complex128> {
|
||||||
|
static void run(const Device& d, Tensor* t) {
|
||||||
|
functor::UnaryFunctor<Device, functor::conj<complex128>> conj;
|
||||||
|
conj(d, t->flat<complex128>() /*out*/,
|
||||||
|
const_cast<const Tensor*>(t)->flat<complex128>() /*in*/);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device, typename T>
|
||||||
|
struct maybe_conj {
|
||||||
|
static void run(const Device& d, const Tensor& in, Tensor* out) { *out = in; }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct maybe_conj<Device, complex64> {
|
||||||
|
static void run(const Device& d, const Tensor& in, Tensor* out) {
|
||||||
|
functor::UnaryFunctor<Device, functor::conj<complex64>> conj;
|
||||||
|
conj(d, out->flat<complex64>() /*out*/, in.flat<complex64>() /*in*/);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct maybe_conj<Device, complex128> {
|
||||||
|
static void run(const Device& d, const Tensor& in, Tensor* out) {
|
||||||
|
functor::UnaryFunctor<Device, functor::conj<complex128>> conj;
|
||||||
|
conj(d, out->flat<complex128>() /*out*/, in.flat<complex128>() /*in*/);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TRANSPOSE_OP_H_
|
93 tensorflow/core/kernels/sparse/zeros_op.cc Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/sparse/zeros_op.h"
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/slice_op.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
class CSRZerosOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit CSRZerosOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||||
|
OP_REQUIRES_OK(c, c->GetAttr("type", &dtype_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* c) override {
|
||||||
|
const Tensor& dense_shape_t = c->input(0);
|
||||||
|
CSRSparseMatrix matrix;
|
||||||
|
functor::CSRSparseMatrixZeros<Device> csr_sparse_matrix_zeros;
|
||||||
|
OP_REQUIRES_OK(c,
|
||||||
|
csr_sparse_matrix_zeros(c, dtype_, dense_shape_t, &matrix));
|
||||||
|
Tensor* matrix_t;
|
||||||
|
AllocatorAttributes cpu_alloc;
|
||||||
|
cpu_alloc.set_on_host(true);
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
c, c->allocate_output(0, TensorShape({}), &matrix_t, cpu_alloc));
|
||||||
|
matrix_t->scalar<Variant>()() = matrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataType dtype_;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx,
|
||||||
|
const CSRSparseMatrix& x,
|
||||||
|
CSRSparseMatrix* y) {
|
||||||
|
functor::CSRSparseMatrixZeros<Device> csr_sparse_matrix_zeros;
|
||||||
|
return csr_sparse_matrix_zeros(ctx, x.dtype(), x.dense_shape(), y);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#ifdef GOOGLE_CUDA
|
||||||
|
#define REGISTER(DEV) \
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("SparseMatrixZeros") \
|
||||||
|
.Device(DEVICE_##DEV) \
|
||||||
|
.HostMemory("dense_shape"), \
|
||||||
|
CSRZerosOp<DEV##Device>);
|
||||||
|
|
||||||
|
REGISTER(GPU)
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(
|
||||||
|
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, CSRSparseMatrix,
|
||||||
|
CSRSparseMatrixZerosLikeHelper<GPUDevice>);
|
||||||
|
|
||||||
|
#undef REGISTER
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
88 tensorflow/core/kernels/sparse/zeros_op.h Normal file
@ -0,0 +1,88 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_ZEROS_OP_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_SPARSE_ZEROS_OP_H_
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
|
#include "tensorflow/core/kernels/dense_update_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
|
namespace functor {
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct CSRSparseMatrixZeros {
|
||||||
|
Status operator()(OpKernelContext* c, DataType dtype,
|
||||||
|
const Tensor& dense_shape_t, CSRSparseMatrix* matrix) {
|
||||||
|
auto dense_shape = dense_shape_t.vec<int64>();
|
||||||
|
const int rank = dense_shape.size();
|
||||||
|
if (!(rank == 2 || rank == 3)) {
|
||||||
|
return errors::InvalidArgument("sparse tensor must have rank == 2 or 3; ",
|
||||||
|
"but dense shape has ", rank, " entries");
|
||||||
|
}
|
||||||
|
const int64 batch_size = (rank == 2) ? 1 : dense_shape(0);
|
||||||
|
const int64 rows = dense_shape((rank == 2) ? 0 : 1);
|
||||||
|
|
||||||
|
Tensor batch_ptr_t(cpu_allocator(), DT_INT32,
|
||||||
|
TensorShape({batch_size + 1}));
|
||||||
|
batch_ptr_t.vec<int32>().setZero(); // On host.
|
||||||
|
|
||||||
|
Allocator* allocator = c->device()->GetAllocator(AllocatorAttributes());
|
||||||
|
// An all-zeros CSR matrix is composed of an empty set of column
|
||||||
|
// indices, an empty set of values, and a vector of all zero row
|
||||||
|
// pointers. The length of the row pointers vector is #rows + 1.
|
||||||
|
// Each row pointer is just an offset into the cols and
|
||||||
|
// values vectors; since those are empty, all coefficients are zero.
|
||||||
|
Tensor csr_row_ptr_t;
|
||||||
|
Tensor coo_col_ind_t(allocator, DT_INT32, TensorShape({0}));
|
||||||
|
Tensor csr_values_t(allocator, dtype, TensorShape({0}));
|
||||||
|
const Device& d = c->eigen_device<Device>();
|
||||||
|
functor::SetZeroFunctor<Device, int32> set_zero;
|
||||||
|
Tensor* csr_row_ptr_t_ptr;
|
||||||
|
PersistentTensor csr_row_ptr_pt;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->allocate_persistent(DT_INT32, TensorShape({batch_size * (rows + 1)}),
|
||||||
|
&csr_row_ptr_pt, &csr_row_ptr_t_ptr));
|
||||||
|
set_zero(d, csr_row_ptr_t_ptr->flat<int32>());
|
||||||
|
csr_row_ptr_t = std::move(*csr_row_ptr_t_ptr);
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix(
|
||||||
|
dtype, dense_shape_t, batch_ptr_t, csr_row_ptr_t, coo_col_ind_t,
|
||||||
|
csr_values_t, matrix));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
};
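// Worked example (illustrative only): for dense_shape = [2, 3, 4], i.e. a
// batch of two 3x4 all-zeros matrices, the functor above produces
//   batch_pointers = [0, 0, 0]                 (no nonzeros in any batch),
//   row_pointers   = [0, 0, 0, 0, 0, 0, 0, 0]  (batch_size * (rows + 1) = 8),
//   col_indices    = []  and  values = []      (both empty).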
|
||||||
|
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_SPARSE_ZEROS_OP_H_
|
613 tensorflow/core/ops/sparse_csr_matrix_ops.cc Normal file
@ -0,0 +1,613 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
using shape_inference::DimensionHandle;
|
||||||
|
using shape_inference::InferenceContext;
|
||||||
|
using shape_inference::ShapeAndType;
|
||||||
|
using shape_inference::ShapeHandle;
|
||||||
|
|
||||||
|
Status GetVariantInput(InferenceContext* c, int index,
|
||||||
|
ShapeAndType* shape_and_type) {
|
||||||
|
ShapeHandle variant;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &variant));
|
||||||
|
auto* shapes_and_types = c->input_handle_shapes_and_types(index);
|
||||||
|
if (shapes_and_types == nullptr || shapes_and_types->size() != 1) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Unable to access shape and type info from variant input ", index);
|
||||||
|
}
|
||||||
|
*shape_and_type = shapes_and_types->at(0);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validates that a shape represents a (rank-2) square matrix or a (rank-3)
|
||||||
|
// batch of square matrices.
|
||||||
|
Status ValidateSquareMatrixShape(InferenceContext* c,
|
||||||
|
const ShapeHandle& matrix_shape,
|
||||||
|
DimensionHandle* matrix_dimension) {
|
||||||
|
ShapeHandle out;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(matrix_shape, 2, &out));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(matrix_shape, 3, &out));
|
||||||
|
if (!c->RankKnown(matrix_shape)) {
|
||||||
|
return errors::Internal("Sparse matrix has an unknown rank.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(c->Dim(matrix_shape, -2),
|
||||||
|
c->Dim(matrix_shape, -1), matrix_dimension));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_OP("SparseTensorToCSRSparseMatrix")
|
||||||
|
.Input("indices: int64")
|
||||||
|
.Input("values: T")
|
||||||
|
.Input("dense_shape: int64")
|
||||||
|
.Attr("T: {float, double, complex64, complex128}")
|
||||||
|
.Output("sparse_matrix: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
|
c, c->input(0), c->input(1), c->input(2)));
|
||||||
|
auto rank = c->Value(c->Dim(c->input(0), 1));
|
||||||
|
ShapeHandle dense_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &dense_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(dense_shape, rank, &dense_shape));
|
||||||
|
if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
|
||||||
|
c->Rank(dense_shape) > 3) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid rank: ", c->Rank(dense_shape),
|
||||||
|
". Expected a known rank of either 2 or 3.");
|
||||||
|
}
|
||||||
|
|
||||||
|
DataType dtype;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
c->set_output_handle_shapes_and_types(0,
|
||||||
|
{ShapeAndType{dense_shape, dtype}});
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("CSRSparseMatrixToSparseTensor")
|
||||||
|
.Input("sparse_matrix: variant")
|
||||||
|
.Output("indices: int64")
|
||||||
|
.Output("values: type")
|
||||||
|
.Output("dense_shape: int64")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
|
||||||
|
if (!c->RankKnown(sparse_matrix)) {
|
||||||
|
return errors::InvalidArgument("sparse_matrix has an unknown rank.");
|
||||||
|
}
|
||||||
|
int rank = c->Rank(sparse_matrix);
|
||||||
|
ShapeHandle indices = c->Matrix(c->UnknownDim(), rank);
|
||||||
|
ShapeHandle values = c->Vector(c->UnknownDim());
|
||||||
|
ShapeHandle dense_shape = c->Vector(rank);
|
||||||
|
c->set_output(0, indices);
|
||||||
|
c->set_output(1, values);
|
||||||
|
c->set_output(2, dense_shape);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("DenseToCSRSparseMatrix")
|
||||||
|
.Input("dense_input: T")
|
||||||
|
.Input("indices: int64")
|
||||||
|
.Attr("T: {float, double, complex64, complex128}")
|
||||||
|
.Output("sparse_output: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeHandle dense_shape = c->input(0);
|
||||||
|
if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
|
||||||
|
c->Rank(dense_shape) > 3) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid rank of dense: ", c->Rank(dense_shape),
|
||||||
|
". Expected a known rank of either 2 or 3.");
|
||||||
|
}
|
||||||
|
auto rank = c->Rank(dense_shape);
|
||||||
|
|
||||||
|
ShapeHandle indices = c->input(1);
|
||||||
|
if (!c->RankKnown(indices) || c->Rank(indices) != 2) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"indices must be a matrix; but its rank is not 2: ",
|
||||||
|
c->Rank(indices));
|
||||||
|
}
|
||||||
|
auto indices_col = c->Dim(indices, 1);
|
||||||
|
if (!c->ValueKnown(indices_col) || c->Value(indices_col) != rank) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"indices.shape[1] must match rank of dense; saw: ",
|
||||||
|
c->Value(indices_col), " vs. ", rank);
|
||||||
|
}
|
||||||
|
ShapeHandle fake_values_vec = c->Vector(c->Dim(indices, 0));
|
||||||
|
ShapeHandle fake_shape_shape = c->Vector(rank);
|
||||||
|
TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
|
||||||
|
c, indices /*indices_shape*/, fake_values_vec /*values_shape*/,
|
||||||
|
fake_shape_shape /*shape_shape*/));
|
||||||
|
DataType dtype;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
|
||||||
|
c->set_output_handle_shapes_and_types(0,
|
||||||
|
{ShapeAndType{dense_shape, dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("CSRSparseMatrixToDense")
|
||||||
|
.Input("sparse_input: variant")
|
||||||
|
.Output("dense_output: type")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
|
||||||
|
if (!c->RankKnown(sparse_matrix)) {
|
||||||
|
return errors::InvalidArgument("sparse_matrix has an unknown rank.");
|
||||||
|
}
|
||||||
|
c->set_output(0, sparse_matrix);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("CSRSparseMatrixComponents")
|
||||||
|
.Input("csr_sparse_matrix: variant")
|
||||||
|
.Input("index: int32")
|
||||||
|
.Output("row_ptrs: int32")
|
||||||
|
.Output("col_inds: int32")
|
||||||
|
.Output("values: type")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle csr_sparse_matrix = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->WithRankAtLeast(csr_sparse_matrix, 2, &csr_sparse_matrix));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->WithRankAtMost(csr_sparse_matrix, 3, &csr_sparse_matrix));
|
||||||
|
ShapeHandle index;
|
||||||
|
if (c->Rank(c->input(1)) != 0) {
|
||||||
|
return errors::InvalidArgument("index must be a scalar.");
|
||||||
|
}
|
||||||
|
if (!c->RankKnown(csr_sparse_matrix)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"csr_sparse_matrix has an unknown rank.");
|
||||||
|
}
|
||||||
|
auto row_ptrs_dh = c->Dim(csr_sparse_matrix, -2);
|
||||||
|
TF_RETURN_IF_ERROR(c->Add(row_ptrs_dh, 1, &row_ptrs_dh));
|
||||||
|
ShapeHandle row_ptrs = c->Vector(row_ptrs_dh);
|
||||||
|
c->set_output(0, row_ptrs);
|
||||||
|
c->set_output(1, c->Vector(c->UnknownDim()));
|
||||||
|
c->set_output(2, c->Vector(c->UnknownDim()));
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
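The row_ptrs output has rows + 1 entries, which is why the shape function above adds one to the row dimension; col_inds and values have one entry per stored nonzero. This is the standard CSR layout. A small illustration using scipy (scipy is not part of this change; it is only a familiar reference point for the layout):

# CSR layout illustration: row_ptrs (indptr) has rows + 1 entries.
import numpy as np
from scipy.sparse import csr_matrix

dense = np.array([[1., 0., 0., 2., 0.],
                  [0., 0., 3., 0., 0.],
                  [0., 0., 0., 0., 0.],
                  [4., 0., 0., 0., 5.]])
m = csr_matrix(dense)   # 4 x 5 matrix
print(m.indptr)         # row_ptrs: [0 2 3 3 5]  (length rows + 1 = 5)
print(m.indices)        # col_inds: [0 3 2 0 4]
print(m.data)           # values:   [1. 2. 3. 4. 5.]
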
REGISTER_OP("SparseMatrixNNZ")
|
||||||
|
.Input("sparse_matrix: variant")
|
||||||
|
.Output("nnz: int32")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle sparse_matrix = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(sparse_matrix, 2, &sparse_matrix));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(sparse_matrix, 3, &sparse_matrix));
|
||||||
|
if (!c->RankKnown(sparse_matrix)) {
|
||||||
|
return errors::InvalidArgument("sparse_matrix has an unknown rank.");
|
||||||
|
}
|
||||||
|
ShapeHandle out;
|
||||||
|
if (c->Rank(sparse_matrix) == 3) {
|
||||||
|
out = c->Vector(c->Dim(sparse_matrix, 0));
|
||||||
|
} else {
|
||||||
|
out = c->Scalar();
|
||||||
|
}
|
||||||
|
c->set_output(0, out);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixMatMul")
|
||||||
|
.Input("a: variant")
|
||||||
|
.Input("b: T")
|
||||||
|
.Attr("T: type")
|
||||||
|
.Attr("transpose_a: bool = false")
|
||||||
|
.Attr("transpose_b: bool = false")
|
||||||
|
.Attr("adjoint_a: bool = false")
|
||||||
|
.Attr("adjoint_b: bool = false")
|
||||||
|
.Attr("transpose_output: bool = false")
|
||||||
|
.Attr("conjugate_output: bool = false")
|
||||||
|
.Output("output: T")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
|
||||||
|
if (!c->RankKnown(a_shape)) {
|
||||||
|
return errors::Internal("a has an unknown rank.");
|
||||||
|
}
|
||||||
|
ShapeHandle b_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
|
||||||
|
|
||||||
|
bool transpose_a = false;
|
||||||
|
bool transpose_b = false;
|
||||||
|
bool transpose_output = false;
|
||||||
|
|
||||||
|
// TODO(ebrevdo): Add transpose support.
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("transpose_output", &transpose_output));
|
||||||
|
|
||||||
|
bool adjoint_a = false;
|
||||||
|
bool adjoint_b = false;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
|
||||||
|
if (adjoint_a && transpose_a) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_a and transpose_a may be true.");
|
||||||
|
}
|
||||||
|
if (adjoint_b && transpose_b) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_b and transpose_b may be true.");
|
||||||
|
}
|
||||||
|
transpose_a = transpose_a || adjoint_a;
|
||||||
|
transpose_b = transpose_b || adjoint_b;
|
||||||
|
|
||||||
|
auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2);
|
||||||
|
auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1);
|
||||||
|
if (transpose_output) {
|
||||||
|
std::tie(output_rows, output_cols) =
|
||||||
|
std::make_tuple(output_cols, output_rows);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Batch dims match between inputs.
|
||||||
|
ShapeHandle a_batch_dims;
|
||||||
|
ShapeHandle b_batch_dims;
|
||||||
|
ShapeHandle batch_dims;
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
|
||||||
|
|
||||||
|
// Assert inner dims match.
|
||||||
|
shape_inference::DimensionHandle unused;
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1),
|
||||||
|
c->Dim(b_shape, transpose_b ? -1 : -2),
|
||||||
|
&unused));
|
||||||
|
|
||||||
|
ShapeHandle out;
|
||||||
|
TF_RETURN_IF_ERROR(c->Concatenate(
|
||||||
|
batch_dims, c->Matrix(output_rows, output_cols), &out));
|
||||||
|
|
||||||
|
c->set_output(0, out);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
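For intuition, the shape function above folds the adjoint flags into the transpose flags, takes the output rows from `a` and columns from `b`, and swaps them when transpose_output is set. A pure-Python mirror of that logic (a sketch for reasoning only, not part of the change):

# Mirrors the SparseMatrixMatMul output-shape logic; shapes are
# (batch..., rows, cols) tuples of known integers.
def matmul_output_shape(a_shape, b_shape,
                        transpose_a=False, transpose_b=False,
                        adjoint_a=False, adjoint_b=False,
                        transpose_output=False):
  assert not (adjoint_a and transpose_a)
  assert not (adjoint_b and transpose_b)
  transpose_a = transpose_a or adjoint_a   # adjoint also swaps dims
  transpose_b = transpose_b or adjoint_b
  assert a_shape[:-2] == b_shape[:-2]      # batch dims must match
  inner_a = a_shape[-2] if transpose_a else a_shape[-1]
  inner_b = b_shape[-1] if transpose_b else b_shape[-2]
  assert inner_a == inner_b                # inner dims must match
  rows = a_shape[-1] if transpose_a else a_shape[-2]
  cols = b_shape[-2] if transpose_b else b_shape[-1]
  if transpose_output:
    rows, cols = cols, rows
  return a_shape[:-2] + (rows, cols)

# e.g. (5, 7, 13) . (5, 13, 15) -> (5, 7, 15)
print(matmul_output_shape((5, 7, 13), (5, 13, 15)))
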
REGISTER_OP("SparseMatrixMul")
|
||||||
|
.Input("a: variant")
|
||||||
|
.Input("b: T")
|
||||||
|
.Attr("T: type")
|
||||||
|
.Output("output: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
|
||||||
|
if (!c->RankKnown(a_shape)) {
|
||||||
|
return errors::Internal("a has an unknown rank.");
|
||||||
|
}
|
||||||
|
ShapeHandle b_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 3, &b_shape));
|
||||||
|
if (!c->RankKnown(b_shape)) {
|
||||||
|
return errors::Internal("b has an unknown rank.");
|
||||||
|
}
|
||||||
|
ShapeHandle out;
|
||||||
|
if (c->Rank(b_shape) == 0) {
|
||||||
|
out = a_shape;
|
||||||
|
} else if (c->Rank(b_shape) == 3) {
|
||||||
|
if (c->Rank(a_shape) != 3) {
|
||||||
|
return errors::Unimplemented("rank of b is 3 but rank of a is not.");
|
||||||
|
}
|
||||||
|
if (!(c->Value(c->Dim(b_shape, 1)) == 1 &&
|
||||||
|
c->Value(c->Dim(b_shape, 2)) == 1)) {
|
||||||
|
return errors::Unimplemented(
|
||||||
|
"b must be a scalar or shaped [batch_size, 1, 1]");
|
||||||
|
}
|
||||||
|
DimensionHandle batch_size = c->Dim(a_shape, 0);
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->Merge(batch_size, c->Dim(b_shape, 0), &batch_size));
|
||||||
|
TF_RETURN_IF_ERROR(c->ReplaceDim(b_shape, 0, batch_size, &b_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->ReplaceDim(a_shape, 0, batch_size, &a_shape));
|
||||||
|
out = a_shape;
|
||||||
|
} else {
|
||||||
|
return errors::Unimplemented(
|
||||||
|
"b must be a scalar or shaped [batch_size, 1, 1]");
|
||||||
|
}
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixAdd")
|
||||||
|
.Input("a: variant")
|
||||||
|
.Input("b: variant")
|
||||||
|
.Input("alpha: T")
|
||||||
|
.Input("beta: T")
|
||||||
|
.Attr("T: {float, double, complex64, complex128}")
|
||||||
|
.Output("c: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
// alpha and beta are scalars.
|
||||||
|
ShapeHandle unused_scalar_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_scalar_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_scalar_shape));
|
||||||
|
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
|
||||||
|
if (!c->RankKnown(a_shape)) {
|
||||||
|
return errors::InvalidArgument("a has an unknown rank.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle b_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
|
||||||
|
if (!c->RankKnown(b_shape)) {
|
||||||
|
return errors::InvalidArgument("b has an unknown rank.");
|
||||||
|
}
|
||||||
|
ShapeHandle out;
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &out));
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
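SparseMatrixAdd above computes alpha * a + beta * b over two CSR matrices of matching shape; the two scalar inputs are what the rank-0 checks at the top of its shape function enforce. In dense terms the contract is simply the following (numpy used only to state the semantics, not how the kernel works):

# Dense-math statement of the SparseMatrixAdd contract.
import numpy as np

def sparse_matrix_add_reference(a, b, alpha, beta):
  assert a.shape == b.shape   # a and b must have the same shape
  return alpha * a + beta * b

a = np.array([[1., 0.], [0., 2.]])
b = np.array([[0., 3.], [0., 4.]])
print(sparse_matrix_add_reference(a, b, alpha=2.0, beta=-1.0))
# [[ 2. -3.]
#  [ 0.  0.]]
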
REGISTER_OP("SparseMatrixSparseMatMul")
|
||||||
|
.Input("a: variant")
|
||||||
|
.Input("b: variant")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.Attr("transpose_a: bool = false")
|
||||||
|
.Attr("transpose_b: bool = false")
|
||||||
|
.Attr("adjoint_a: bool = false")
|
||||||
|
.Attr("adjoint_b: bool = false")
|
||||||
|
.Output("c: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle a_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(a_shape, 2, &a_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(a_shape, 3, &a_shape));
|
||||||
|
if (!c->RankKnown(a_shape)) {
|
||||||
|
return errors::Internal("a has an unknown rank.");
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle b_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(b_shape, 2, &b_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(b_shape, 3, &b_shape));
|
||||||
|
if (!c->RankKnown(b_shape)) {
|
||||||
|
return errors::Internal("b has an unknown rank.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool transpose_a = false;
|
||||||
|
bool transpose_b = false;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a));
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b));
|
||||||
|
bool adjoint_a = false;
|
||||||
|
bool adjoint_b = false;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));
|
||||||
|
if (adjoint_a && transpose_a) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_a and transpose_a may be true.");
|
||||||
|
} else if (adjoint_b && transpose_b) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Only one of adjoint_b and transpose_b may be true.");
|
||||||
|
}
|
||||||
|
transpose_a = transpose_a || adjoint_a;
|
||||||
|
transpose_b = transpose_b || adjoint_b;
|
||||||
|
|
||||||
|
auto output_rows = c->Dim(a_shape, transpose_a ? -1 : -2);
|
||||||
|
auto output_cols = c->Dim(b_shape, transpose_b ? -2 : -1);
|
||||||
|
|
||||||
|
// Batch dims match between inputs.
|
||||||
|
ShapeHandle a_batch_dims;
|
||||||
|
ShapeHandle b_batch_dims;
|
||||||
|
ShapeHandle batch_dims;
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
|
||||||
|
|
||||||
|
// Assert inner dims match.
|
||||||
|
shape_inference::DimensionHandle unused;
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, transpose_a ? -2 : -1),
|
||||||
|
c->Dim(b_shape, transpose_b ? -1 : -2),
|
||||||
|
&unused));
|
||||||
|
|
||||||
|
ShapeHandle out;
|
||||||
|
TF_RETURN_IF_ERROR(c->Concatenate(
|
||||||
|
batch_dims, c->Matrix(output_rows, output_cols), &out));
|
||||||
|
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixZeros")
|
||||||
|
.Input("dense_shape: int64")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.Output("sparse_matrix: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
auto rank = c->NumElements(c->input(0));
|
||||||
|
ShapeHandle dense_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &dense_shape));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->WithRank(dense_shape, c->Value(rank), &dense_shape));
|
||||||
|
if (!c->RankKnown(dense_shape) || c->Rank(dense_shape) < 2 ||
|
||||||
|
c->Rank(dense_shape) > 3) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Invalid rank: ", c->Rank(dense_shape),
|
||||||
|
". Expected a known rank of either 2 or 3.");
|
||||||
|
}
|
||||||
|
DataType dtype;
|
||||||
|
TF_RETURN_IF_ERROR(c->GetAttr("type", &dtype));
|
||||||
|
c->set_output_handle_shapes_and_types(0,
|
||||||
|
{ShapeAndType{dense_shape, dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixTranspose")
|
||||||
|
.Input("input: variant")
|
||||||
|
.Attr("conjugate: bool = false")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.Output("output: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle input = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(input, 3, &input));
|
||||||
|
if (!c->RankKnown(input)) {
|
||||||
|
return errors::InvalidArgument("input has an unknown rank.");
|
||||||
|
}
|
||||||
|
ShapeHandle output;
|
||||||
|
if (c->Rank(input) == 2) {
|
||||||
|
output = c->Matrix(c->Dim(input, 1), c->Dim(input, 0));
|
||||||
|
} else {
|
||||||
|
output = c->MakeShape(
|
||||||
|
{c->Dim(input, 0), c->Dim(input, 2), c->Dim(input, 1)});
|
||||||
|
}
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{output, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixSoftmax")
|
||||||
|
.Input("logits: variant")
|
||||||
|
.Attr("type: {float, double}")
|
||||||
|
.Output("softmax: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle logits = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(logits, 2, &logits));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(logits, 3, &logits));
|
||||||
|
if (!c->RankKnown(logits)) {
|
||||||
|
return errors::InvalidArgument("logits has an unknown rank.");
|
||||||
|
}
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{logits, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
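A note on SparseMatrixSoftmax: the shape function only forwards the logits shape, so the semantics are not visible here. Assuming missing entries behave as -inf logits (so the sparsity pattern is preserved and the softmax is taken row-wise over the stored entries only), the computation can be sketched as follows; scipy is used purely as a CSR container for the sketch:

# Row-wise softmax over only the stored entries of a CSR matrix
# (assumption: missing entries act as -inf logits, pattern unchanged).
import numpy as np
from scipy.sparse import csr_matrix

def csr_row_softmax(m):
  out = m.astype(np.float64)                # astype copies the matrix
  for row in range(out.shape[0]):
    lo, hi = out.indptr[row], out.indptr[row + 1]
    if hi > lo:
      vals = out.data[lo:hi]
      vals = np.exp(vals - vals.max())      # numerically stabilized exp
      out.data[lo:hi] = vals / vals.sum()
  return out

m = csr_matrix(np.array([[1., 0., 2.], [0., 0., 0.], [3., 3., 0.]]))
print(csr_row_softmax(m).toarray())
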
REGISTER_OP("SparseMatrixSoftmaxGrad")
|
||||||
|
.Input("softmax: variant")
|
||||||
|
.Input("grad_softmax: variant")
|
||||||
|
.Attr("type: {float, double}")
|
||||||
|
.Output("gradient: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle softmax = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(softmax, 2, &softmax));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(softmax, 3, &softmax));
|
||||||
|
if (!c->RankKnown(softmax)) {
|
||||||
|
return errors::InvalidArgument("softmax has an unknown rank.");
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 1, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle grad_softmax = sparse_matrix_shape_and_type.shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(grad_softmax, 2, &grad_softmax));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(grad_softmax, 3, &grad_softmax));
|
||||||
|
if (!c->RankKnown(grad_softmax)) {
|
||||||
|
return errors::InvalidArgument("grad_softmax has an unknown rank.");
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(softmax, grad_softmax, &softmax));
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{softmax, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixOrderingAMD")
|
||||||
|
.Input("input: variant")
|
||||||
|
.Output("output: int32")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
DimensionHandle n;
|
||||||
|
TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n));
|
||||||
|
|
||||||
|
ShapeHandle output;
|
||||||
|
if (c->Rank(matrix_shape) == 2) {
|
||||||
|
output = c->Vector(c->Dim(matrix_shape, 0));
|
||||||
|
} else {
|
||||||
|
output = c->Matrix(c->Dim(matrix_shape, 0), c->Dim(matrix_shape, 1));
|
||||||
|
}
|
||||||
|
c->set_output(0, output);
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
REGISTER_OP("SparseMatrixSparseCholesky")
|
||||||
|
.Input("input: variant")
|
||||||
|
.Input("permutation: int32")
|
||||||
|
.Attr("type: {float, double, complex64, complex128}")
|
||||||
|
.Output("output: variant")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeAndType sparse_matrix_shape_and_type;
|
||||||
|
TF_RETURN_IF_ERROR(GetVariantInput(c, 0, &sparse_matrix_shape_and_type));
|
||||||
|
ShapeHandle matrix_shape = sparse_matrix_shape_and_type.shape;
|
||||||
|
DimensionHandle n;
|
||||||
|
TF_RETURN_IF_ERROR(ValidateSquareMatrixShape(c, matrix_shape, &n));
|
||||||
|
|
||||||
|
ShapeHandle perm_shape;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &perm_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &perm_shape));
|
||||||
|
if (!c->RankKnown(perm_shape)) {
|
||||||
|
return errors::Internal("permutation has an unknown rank.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each batch component of permutation must have the same number of
|
||||||
|
// elements as number of rows of sparse_matrix.
|
||||||
|
TF_RETURN_IF_ERROR(c->Merge(n, c->Dim(perm_shape, -1), &n));
|
||||||
|
ShapeHandle matrix_batch_shape;
|
||||||
|
ShapeHandle perm_batch_shape;
|
||||||
|
|
||||||
|
// Make the common batch subshape.
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(matrix_shape, 0, -2, &matrix_batch_shape));
|
||||||
|
TF_RETURN_IF_ERROR(c->Subshape(perm_shape, 0, -1, &perm_shape));
|
||||||
|
// Make sure the batch dimensions match between sparse_matrix and
|
||||||
|
// permutation.
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
c->Merge(matrix_batch_shape, perm_batch_shape, &matrix_batch_shape));
|
||||||
|
|
||||||
|
ShapeHandle out = matrix_shape;
|
||||||
|
c->set_output_handle_shapes_and_types(
|
||||||
|
0, {ShapeAndType{out, sparse_matrix_shape_and_type.dtype}});
|
||||||
|
c->set_output(0, c->Scalar());
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
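The ordering and factorization ops are designed to compose: SparseMatrixOrderingAMD produces a fill-reducing permutation whose length must match the matrix dimension (that is the Merge of n with the last permutation dimension above), and SparseMatrixSparseCholesky consumes it. A hedged Python sketch, again assuming the generated wrapper names in gen_sparse_csr_matrix_ops and a symmetric positive-definite float64 input:

# Hypothetical usage sketch; wrapper names and module path are assumptions.
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import gen_sparse_csr_matrix_ops as gen_ops

def sparse_cholesky_with_amd(a_csr):
  # Fill-reducing permutation: one int32 vector per (batch) matrix.
  perm = gen_ops.sparse_matrix_ordering_amd(a_csr)
  # Factor the permuted matrix; the output is again a CSRSparseMatrix.
  return gen_ops.sparse_matrix_sparse_cholesky(a_csr, perm, type=tf.float64)
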
tensorflow/core/ops/sparse_csr_matrix_ops_test.cc
@ -0,0 +1,369 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

TEST(SparseMatrixOpsTest, SparseTensorToCSRSparseMatrix_ShapeFn) {
  ShapeInferenceTestOp op("SparseTensorToCSRSparseMatrix");
  (*op.node_def.mutable_attr())["T"].set_type(DT_FLOAT);
  op.input_tensors.resize(3);
  // inputs: indices, values, dense_shape
  INFER_ERROR("Expected a known rank", op, "?;?;?");
  INFER_ERROR("either 2 or 3", op, "[?,4];?;?");
  INFER_OK(op, "[?,2];?;?", "[]");
  INFER_OK(op, "[?,3];?;?", "[]");
  Tensor dense_shape_t = test::AsTensor<int64>({5, 6});
  op.input_tensors[2] = &dense_shape_t;
  INFER_ERROR("Shape must be rank 3 but is rank 2 for", op, "[?,3];?;?");
  INFER_OK(op, "[?,2];?;?", "[]");
}

TEST(SparseMatrixOpsTest, CSRSparseMatrixToSparseTensor_ShapeFn) {
  ShapeInferenceTestOp op("CSRSparseMatrixToSparseTensor");
  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types(1);
  shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
  // outputs: indices, values, dense_shape
  shapes_and_types[0].first = "[4,5]";
  INFER_OK(op, "[]", "[?,2];[?];[2]");
  shapes_and_types[0].first = "[?,?]";
  INFER_OK(op, "[]", "[?,2];[?];[2]");
  shapes_and_types[0].first = "[4,5,6]";
  INFER_OK(op, "[]", "[?,3];[?];[3]");
  shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[]", "[?,3];[?];[3]");
}

TEST(SparseMatrixOpsTest, DenseToCSRSparseMatrix_ShapeFn) {
  ShapeInferenceTestOp op("DenseToCSRSparseMatrix");
  (*op.node_def.mutable_attr())["T"].set_type(DT_FLOAT);
  INFER_ERROR("Expected a known rank", op, "?;?");
  INFER_ERROR("either 2 or 3", op, "[?];?");
  INFER_OK(op, "[?,?];[?,2]", "[]");
  INFER_OK(op, "[?,?,?];[?,3]", "[]");
  INFER_ERROR("indices.shape[1] must match rank of dense; saw: 2 vs. 3", op,
              "[?,?,?];[?,2]");
}

TEST(SparseMatrixOpsTest, CSRSparseMatrixToDense_ShapeFn) {
  ShapeInferenceTestOp op("CSRSparseMatrixToDense");
  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types(1);
  shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
  // outputs: dense
  shapes_and_types[0].first = "[?,?]";
  INFER_OK(op, "[]", "[?,?]");
  shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[]", "[?,?,?]");
}

TEST(SparseMatrixOpsTest, CSRSparseMatrixComponents_ShapeFn) {
  ShapeInferenceTestOp op("CSRSparseMatrixComponents");
  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types(1);
  shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(nullptr);
  // inputs: csr_sparse_matrix, index
  // outputs: row_ptrs, col_inds, values
  shapes_and_types[0].first = "[4,5]";
  INFER_OK(op, "[];[]", "[5];[?];[?]");
  shapes_and_types[0].first = "[?,?]";
  INFER_OK(op, "[];[]", "[?];[?];[?]");
  shapes_and_types[0].first = "[19,34,55]";
  INFER_OK(op, "[];[]", "[35];[?];[?]");
  shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[];[]", "[?];[?];[?]");
  shapes_and_types[0].first = "[?,?,?]";
  INFER_ERROR("index must be a scalar", op, "[];?");
}

TEST(SparseMatrixOpsTest, SparseMatrixMatMul_ShapeFn) {
  ShapeInferenceTestOp op("SparseMatrixMatMul");
  std::vector<ShapeInferenceTestOp::ShapeAndType> a_shapes_and_types(1);
  a_shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&a_shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(nullptr);
  auto set_options = [&op](bool transpose_a, bool transpose_b, bool adjoint_a,
                           bool adjoint_b, bool transpose_output) {
    TF_ASSERT_OK(NodeDefBuilder("test", "SparseMatrixMatMul")
                     .Input("a", 0, DT_VARIANT)
                     .Input("b", 1, DT_FLOAT)
                     .Attr("transpose_a", transpose_a)
                     .Attr("transpose_b", transpose_b)
                     .Attr("adjoint_a", adjoint_a)
                     .Attr("adjoint_b", adjoint_b)
                     .Attr("transpose_output", transpose_output)
                     .Finalize(&op.node_def));
  };
  // inputs: a <CSR>, b <T>
  // output: matmul(a, b)
  set_options(false, false, false, false, false /*transpose_output*/);
  a_shapes_and_types[0].first = "?";
  INFER_ERROR("a has an unknown rank", op, "[];?");
  a_shapes_and_types[0].first = "[?]";
  INFER_ERROR("must be at least rank 2 but is rank 1", op, "[];?");
  a_shapes_and_types[0].first = "[?,?]";
  INFER_OK(op, "[];?", "[?,?]");
  a_shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[];?", "[?,?,?]");
  a_shapes_and_types[0].first = "[?,3,?]";
  INFER_OK(op, "[];[?,?,?]", "[?,3,d1_2]");
  a_shapes_and_types[0].first = "[?,3,?]";
  INFER_OK(op, "[];[?,?,4]", "[?,3,d1_2]");  // [B,3,?] . [B,?,4]
  a_shapes_and_types[0].first = "[?,?,6]";
  INFER_OK(op, "[];[?,6,?]", "[?,?,d1_2]");  // [B,?,6] . [B,6,?]
  a_shapes_and_types[0].first = "[?,?,5]";
  INFER_ERROR("must be equal, but are 5 and 6 for", op, "[];[?,6,?]");

  set_options(false, false, false, false, true /*transpose_output*/);
  a_shapes_and_types[0].first = "[?,3,?]";
  INFER_OK(op, "[];[?,?,4]", "[?,d1_2,3]");
  a_shapes_and_types[0].first = "[3,?]";
  INFER_OK(op, "[];[?,4]", "[d1_1,3]");

  set_options(/*transpose_a=*/true, /*transpose_b=*/true,
              /*adjoint_a=*/false, /*adjoint_b=*/false,
              false /*transpose_output*/);
  // t([B,W,X]) . t([B,Y,Z]) => [B,X,Y]
  a_shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[];[?,?,?]", "[?,?,d1_1]");

  set_options(/*transpose_a=*/false, /*transpose_b=*/false,
              /*adjoint_a=*/true, /*adjoint_b=*/true,
              false /*transpose_output*/);
  // adj([B,W,X]) . adj([B,Y,Z]) => [B,X,Y]
  a_shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[];[?,?,?]", "[?,?,d1_1]");

  set_options(true /*transpose_a*/, true /*transpose_b*/,
              /*adjoint_a=*/false, /*adjoint_b=*/false,
              true /*transpose_output*/);
  // t(t([B,W,X]) . t([B,Y,Z])) => [B,Y,X]
  a_shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[];[?,?,?]", "[?,d1_1,?]");

  set_options(/*transpose_a=*/true, /*transpose_b=*/false,
              /*adjoint_a=*/true, /*adjoint_b=*/true,
              false /*transpose_output*/);
  a_shapes_and_types[0].first = "[?,?,?]";
  INFER_ERROR("Only one of adjoint_a and transpose_a", op, "[];[?,?,?]");
  set_options(/*transpose_a=*/false, /*transpose_b=*/true,
              /*adjoint_a=*/true, /*adjoint_b=*/true,
              false /*transpose_output*/);
  a_shapes_and_types[0].first = "[?,?,?]";
  INFER_ERROR("Only one of adjoint_b and transpose_b", op, "[];[?,?,?]");
}

TEST(SparseMatrixOpsTest, SparseMatrixAdd_ShapeFn) {
  // inputs: a <CSR>, b <CSR>, alpha <scalar>, beta <scalar>
  // output: alpha * a + beta * b
  ShapeInferenceTestOp op("SparseMatrixAdd");
  std::vector<ShapeInferenceTestOp::ShapeAndType> a_shapes_and_types(1);
  std::vector<ShapeInferenceTestOp::ShapeAndType> b_shapes_and_types(1);
  a_shapes_and_types[0].second = DT_FLOAT;
  b_shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&a_shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(&b_shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(nullptr);
  op.input_resource_handle_shapes_and_types.push_back(nullptr);
  auto set_shapes = [&a_shapes_and_types, &b_shapes_and_types](
                        const string& a_shape, const string& b_shape) {
    a_shapes_and_types[0].first = a_shape;
    b_shapes_and_types[0].first = b_shape;
  };
  // TODO(ebrevdo): Update shape_inference_testutil to be able to properly test
  // output handle shapes and types.
  set_shapes("[?,?]", "[?,?]");
  INFER_OK(op, "[];[];?;?", "[]");  // output handle: [?,?]
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_OK(op, "[];[];?;?", "[]");  // output handle: [?,?,?]
  set_shapes("[3,4]", "[3,4]");
  INFER_OK(op, "[];[];?;?", "[]");  // output handle: [3,4]
  set_shapes("[3,4,5]", "[3,4,5]");
  INFER_OK(op, "[];[];?;?", "[]");  // output handle: [3,4,5]
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_OK(op, "[];[];[];[]", "[]");  // output handle: [?,?,?]
  // non-scalar beta.
  set_shapes("[?,?]", "[?,?]");
  INFER_ERROR("must be rank 0 but is rank 1", op, "[];[];?;[?]");
  // unknown rank b.
  set_shapes("[?,?,?]", "?");
  INFER_ERROR("b has an unknown rank", op, "[];[];?;?");
  // different ranks of a and b.
  set_shapes("[?,?,?]", "[?,?]");
  INFER_ERROR("must be equal", op, "[];[];?;?");
}

TEST(SparseMatrixOpsTest, SparseMatrixSparseMatMul_ShapeFn) {
  ShapeInferenceTestOp op("SparseMatrixSparseMatMul");
  std::vector<ShapeInferenceTestOp::ShapeAndType> a_shapes_and_types(1);
  std::vector<ShapeInferenceTestOp::ShapeAndType> b_shapes_and_types(1);
  a_shapes_and_types[0].second = DT_FLOAT;
  b_shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&a_shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(&b_shapes_and_types);
  auto set_shapes = [&a_shapes_and_types, &b_shapes_and_types](
                        const string& a_shape, const string& b_shape) {
    a_shapes_and_types[0].first = a_shape;
    b_shapes_and_types[0].first = b_shape;
  };
  auto set_options = [&op](bool transpose_a, bool transpose_b, bool adjoint_a,
                           bool adjoint_b) {
    TF_ASSERT_OK(NodeDefBuilder("test", "SparseMatrixSparseMatMul")
                     .Input("a", 0, DT_VARIANT)
                     .Input("b", 1, DT_FLOAT)
                     .Attr("transpose_a", transpose_a)
                     .Attr("transpose_b", transpose_b)
                     .Attr("adjoint_a", adjoint_a)
                     .Attr("adjoint_b", adjoint_b)
                     .Finalize(&op.node_def));
  };
  // inputs: a <CSR>, b <CSR>
  // output: matmul(a, b) <CSR>
  set_options(false, false, false, false);
  set_shapes("?", "?");
  INFER_ERROR("has an unknown rank", op, "[];[]");
  set_shapes("[?]", "[?,?]");
  INFER_ERROR("must be at least rank 2 but is rank 1", op, "[];[]");
  set_shapes("[?,?]", "[?,?]");
  INFER_OK(op, "[];[]", "[]");  // [d0_0,d1_1]"
  set_shapes("[?,?,?]", "[?,?]");
  INFER_ERROR("must be equal rank, but are", op, "[];[]");
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_OK(op, "[];[]", "[]");  // "[d0_0,d0_1,d1_2]"
  set_shapes("[?,3,?]", "[?,?,?]");
  INFER_OK(op, "[];[]", "[]");  // "[d0_0,d0_1,d1_2]"
  set_shapes("[?,3,?]", "[?,?,4]");
  INFER_OK(op, "[];[]", "[]");  // [d0_0,d0_1,d1_2]"
  set_shapes("[?,?,6]", "[?,6,?]");
  INFER_OK(op, "[];[]", "[]");  // "[d0_0,d0_1,d1_2]"
  set_shapes("[?,?,5]", "[?,6,?]");
  INFER_ERROR("must be equal, but are 5 and 6 for", op, "[];[]");

  set_options(/*transpose_a=*/true, /*transpose_b=*/true, /*adjoint_a=*/false,
              /*adjoint_b=*/false);
  // t([B,W,X]) . t([B,Y,Z]) => [B,X,Y]
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_OK(op, "[];[]", "[]");  // [d0_0,d0_2,d1_1]"

  set_options(/*transpose_a=*/false, /*transpose_b=*/false, /*adjoint_a=*/true,
              /*adjoint_b=*/true);
  // adj([B,W,X]) . adj([B,Y,Z]) => [B,X,Y]
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_OK(op, "[];[]", "[]");  // "[d0_0,d0_2,d1_1]"

  set_options(/*transpose_a=*/true, /*transpose_b=*/false,
              /*adjoint_a=*/true, /*adjoint_b=*/true);
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_ERROR("Only one of adjoint_a and transpose_a", op, "[];[]");
  set_options(/*transpose_a=*/false, /*transpose_b=*/true,
              /*adjoint_a=*/true, /*adjoint_b=*/true);
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_ERROR("Only one of adjoint_b and transpose_b", op, "[];[]");
}

TEST(SparseMatrixOpsTest, SparseMatrixTranspose_ShapeFn) {
  ShapeInferenceTestOp op("SparseMatrixTranspose");
  // inputs: input
  // outputs: output
  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types(1);
  shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
  shapes_and_types[0].first = "[3,4,5]";
  INFER_OK(op, "[]", "[]");  // [3,5,4]"
  shapes_and_types[0].first = "[3,4]";
  INFER_OK(op, "[]", "[]");  // "[4, 3]";
  shapes_and_types[0].first = "?";
  INFER_ERROR("input has an unknown rank", op, "[]");
}

TEST(SparseMatrixOpsTest, SparseMatrixSoftmax_ShapeFn) {
  ShapeInferenceTestOp op("SparseMatrixSoftmax");
  // inputs: logits
  // outputs: softmax
  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types(1);
  shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
  shapes_and_types[0].first = "[?,?,?]";
  INFER_OK(op, "[]", "[]");  // "in0"
  shapes_and_types[0].first = "[?,?]";
  INFER_OK(op, "[]", "[]");  // "in0"
  shapes_and_types[0].first = "?";
  INFER_ERROR("logits has an unknown rank", op, "[]");
}

TEST(SparseMatrixOpsTest, SparseMatrixSoftmaxGrad_ShapeFn) {
  ShapeInferenceTestOp op("SparseMatrixSoftmaxGrad");
  // inputs: softmax, grad_softmax
  // outputs: gradient
  std::vector<ShapeInferenceTestOp::ShapeAndType> a_shapes_and_types(1);
  std::vector<ShapeInferenceTestOp::ShapeAndType> b_shapes_and_types(1);
  a_shapes_and_types[0].second = DT_FLOAT;
  b_shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&a_shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(&b_shapes_and_types);
  auto set_shapes = [&a_shapes_and_types, &b_shapes_and_types](
                        const string& a_shape, const string& b_shape) {
    a_shapes_and_types[0].first = a_shape;
    b_shapes_and_types[0].first = b_shape;
  };
  set_shapes("[?,?,?]", "[?,?,?]");
  INFER_OK(op, "[];[]", "[]");  // "in0"
  set_shapes("[?,?]", "[?,?]");
  INFER_OK(op, "[];[]", "[]");  // "in0"
  set_shapes("[3,4]", "[5,6]");
  INFER_ERROR("Dimension 0 in both shapes must be equal, but are 3 and 5", op,
              "[];[]");
  set_shapes("?", "[?,?]");
  INFER_ERROR("softmax has an unknown rank", op, "[];[]");
  set_shapes("[?,?,?]", "?");
  INFER_ERROR("grad_softmax has an unknown rank", op, "[];[]");
}

TEST(SparseMatrixOpsTest, SparseMatrixMul_ShapeFn) {
  ShapeInferenceTestOp op("SparseMatrixMul");
  // inputs: a <CSR>, b <dense>
  // output: a * b
  std::vector<ShapeInferenceTestOp::ShapeAndType> shapes_and_types(1);
  shapes_and_types[0].second = DT_FLOAT;
  op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types);
  op.input_resource_handle_shapes_and_types.push_back(nullptr);
  shapes_and_types[0].first = "[3,4]";
  INFER_OK(op, "[];[]", "[]");  // "[3,4]"
  shapes_and_types[0].first = "[5,3,4]";
  INFER_OK(op, "[];[?,1,1]", "[]");  // "[5,3,4]"
  // b not scalar, doesn't match a.
  shapes_and_types[0].first = "[?,?,?]";
  INFER_ERROR("b must be a scalar or shaped [batch_size, 1, 1]", op,
              "[];[3,4]");
  shapes_and_types[0].first = "[3,4]";
  INFER_ERROR("b must be a scalar or shaped", op, "[];[3,4]");
  shapes_and_types[0].first = "[3,4,5]";
  INFER_ERROR("b must be a scalar or shaped", op, "[];[3,4,5]");
  shapes_and_types[0].first = "[3,4,5]";
  INFER_ERROR("must be equal, but are 3 and 4", op, "[];[4,1,1]");
}

}  // namespace tensorflow
@ -38,8 +38,11 @@ tensorflow/third_party/eigen3/Eigen/Cholesky
 tensorflow/third_party/eigen3/Eigen/Core
 tensorflow/third_party/eigen3/Eigen/Eigenvalues
 tensorflow/third_party/eigen3/Eigen/LU
+tensorflow/third_party/eigen3/Eigen/OrderingMethods
 tensorflow/third_party/eigen3/Eigen/QR
 tensorflow/third_party/eigen3/Eigen/SVD
+tensorflow/third_party/eigen3/Eigen/SparseCholesky
+tensorflow/third_party/eigen3/Eigen/SparseCore
 tensorflow/third_party/eigen3/LICENSE
 tensorflow/third_party/eigen3/gpu_packet_math.patch
 tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
@ -184,6 +184,7 @@ py_library(
         "//tensorflow/python/module",
         "//tensorflow/python/ops/distributions",
         "//tensorflow/python/ops/linalg",
+        "//tensorflow/python/ops/linalg/sparse",
         "//tensorflow/python/ops/losses",
         "//tensorflow/python/ops/parallel_for",
         "//tensorflow/python/ops/ragged",
@ -2875,6 +2876,7 @@ py_library(
         ":tensor_array_ops",
         ":unconnected_gradients",
         ":util",
+        "//tensorflow/python/ops/linalg/sparse",
     ],
 )

@ -3861,6 +3863,7 @@ py_library(
         "//tensorflow/python/eager:wrap_function",
         "//tensorflow/python/ops/distributions",
         "//tensorflow/python/ops/linalg",
+        "//tensorflow/python/ops/linalg/sparse",
         "//tensorflow/python/ops/ragged",
     ],
 )
@ -8,7 +8,7 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )

-# CPU only tests should use tf_py_test, GPU tests use cuda_py_test
+# CPU-only tests should use tf_py_test, GPU tests use cuda_py_test
 # Please avoid the py_tests and cuda_py_tests (plural) while we
 # fix the shared/overbroad dependencies.

@ -3878,3 +3878,57 @@ cuda_py_test(
     tags = ["no_rocm"],
     xla_enable_strict_auto_jit = True,
 )
+
+cuda_py_test(
+    name = "sparse_csr_matrix_ops_test",
+    size = "medium",
+    srcs = ["sparse_csr_matrix_ops_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg/sparse",
+        "//tensorflow/python/ops/linalg/sparse:gen_sparse_csr_matrix_ops",
+    ],
+    main = "sparse_csr_matrix_ops_test.py",
+)
+
+cuda_py_test(
+    name = "csr_sparse_matrix_test",
+    size = "medium",
+    srcs = ["csr_sparse_matrix_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg/sparse",
+    ],
+    main = "csr_sparse_matrix_test.py",
+)
+
+cuda_py_test(
+    name = "sparse_csr_matrix_grad_test",
+    size = "medium",
+    srcs = ["sparse_csr_matrix_grad_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg/sparse",
+    ],
+    main = "sparse_csr_matrix_grad_test.py",
+    shard_count = 50,
+)
+
+cuda_py_test(
+    name = "sparse_csr_matrix_dense_mat_mul_grad_test",
+    size = "medium",
+    srcs = ["sparse_csr_matrix_dense_mat_mul_grad_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg/sparse",
+    ],
+    main = "sparse_csr_matrix_dense_mat_mul_grad_test.py",
+    shard_count = 50,
+)
+
+cuda_py_test(
+    name = "sparse_csr_matrix_sparse_mat_mul_grad_test",
+    size = "medium",
+    srcs = ["sparse_csr_matrix_sparse_mat_mul_grad_test.py"],
+    additional_deps = [
+        "//tensorflow/python/ops/linalg/sparse",
+    ],
+    main = "sparse_csr_matrix_sparse_mat_mul_grad_test.py",
+    shard_count = 50,
+)

tensorflow/python/kernel_tests/csr_sparse_matrix_test.py
@ -0,0 +1,266 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CSR sparse matrix tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
from tensorflow.python.platform import test


class CSRSparseMatrixTest(test.TestCase):

  @classmethod
  def setUpClass(cls):  # pylint: disable=g-missing-super-call
    cls._gpu_available = test_util.is_gpu_available()

  @test_util.run_in_graph_and_eager_modes
  def testConstructorFromSparseTensor(self):
    if not self._gpu_available:
      return

    a_indices = np.array([[0, 0], [2, 3], [2, 4], [3, 0]])
    a_values = [1.0, 5.0, -1.0, -2.0]
    a_dense_shape = [5, 6]

    a_st = sparse_tensor.SparseTensor(a_indices, a_values, a_dense_shape)
    a_st = math_ops.cast(a_st, dtypes.float32)
    a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_st)
    self.assertEqual(a_sm.shape, a_dense_shape)

    a_st_rt = a_sm.to_sparse_tensor()
    a_st_rt = self.evaluate(a_st_rt)

    self.assertAllEqual(a_indices, a_st_rt.indices)
    self.assertAllClose(a_values, a_st_rt.values)
    self.assertAllEqual(a_dense_shape, a_st_rt.dense_shape)

  @test_util.run_in_graph_and_eager_modes
  def testConstructorFromDenseTensorNoIndices(self):
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (m > 0)
    dense_shape = [5, 7, 13]
    a_mats = sparsify(np.random.randn(*dense_shape)).astype(np.float32)

    a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
    self.assertEqual(a_sm.shape, a_mats.shape)

    a_sm_rt = a_sm.to_dense()
    a_sm_nnz = a_sm.nnz()
    a_sm_nnz, a_sm_rt = self.evaluate([a_sm_nnz, a_sm_rt])

    # Count number of nonzero entries for each batch using bincount.
    nz = np.bincount(a_mats.nonzero()[0], minlength=a_mats.shape[0])
    self.assertAllEqual(nz, a_sm_nnz)
    self.assertAllClose(a_mats, a_sm_rt)

  @test_util.run_in_graph_and_eager_modes
  def testConstructorFromDenseTensorWithIndices(self):
    if not self._gpu_available:
      return

    dense_shape = [5, 7, 13]
    a_mats = np.random.randn(*dense_shape).astype(np.float32)
    indices = np.array([[0, 0, 0],
                        [1, 0, 0]], dtype=np.int64)

    a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats, indices=indices)
    self.assertEqual(a_sm.shape, a_mats.shape)

    a_sm_st = a_sm.to_sparse_tensor()
    a_sm_st = self.evaluate(a_sm_st)

    # Count number of nonzero entries for each batch using bincount.
    self.assertAllEqual(indices, a_sm_st.indices)
    self.assertAllEqual(dense_shape, a_sm.shape)
    self.assertAllEqual(dense_shape, a_sm_st.dense_shape)
    self.assertAllClose([a_mats[tuple(x)] for x in indices], a_sm_st.values)

  @test_util.run_in_graph_and_eager_modes
  def testConj(self):
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (m.real > 0)
    dense_shape = [5, 7, 13]
    a_mats = sparsify(
        (np.random.randn(*dense_shape) + 1.j * np.random.randn(*dense_shape))
        .astype(np.complex64))
    a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
    a_sm_conj = a_sm.conj()
    self.assertIsInstance(a_sm_conj, sparse_csr_matrix_ops.CSRSparseMatrix)
    a_sm_conj_dense = a_sm_conj.to_dense()
    a_sm_conj_dense = self.evaluate(a_sm_conj_dense)
    self.assertAllClose(a_mats.conj(), a_sm_conj_dense)

  @test_util.run_in_graph_and_eager_modes
  def testTranspose(self):
    if not self._gpu_available:
      return

    for conjugate in False, True:
      sparsify = lambda m: m * (m > 0)
      dense_shape = [5, 7, 13]
      a_mats = sparsify((np.random.randn(*dense_shape) +
                         1.j * np.random.randn(*dense_shape))).astype(
                             np.complex64)
      expected = np.transpose(a_mats, (0, 2, 1))
      if conjugate:
        expected = np.conj(expected)
      a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
      if conjugate:
        a_sm_t = a_sm.hermitian_transpose()
      else:
        a_sm_t = a_sm.transpose()
      self.assertIsInstance(a_sm_t, sparse_csr_matrix_ops.CSRSparseMatrix)
      a_sm_t_dense = a_sm_t.to_dense()
      a_sm_t_dense = self.evaluate(a_sm_t_dense)
      self.assertAllClose(expected, a_sm_t_dense)


class SparseMatrixMatmulTest(test.TestCase):

  @classmethod
  def setUpClass(cls):  # pylint: disable=g-missing-super-call
    cls._gpu_available = test_util.is_gpu_available()

  def _testSparseSparse(self, transpose_a, transpose_b, adjoint_a, adjoint_b):
    if not self._gpu_available:
      return
    sparsify = lambda m: m * (m > 0)
    dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
    dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
    for dtype in np.float32, np.complex64:
      a_mats = sparsify((np.random.randn(*dense_shape_a) +
                         1.j * np.random.randn(*dense_shape_a))).astype(dtype)
      b_mats = sparsify((np.random.randn(*dense_shape_b) +
                         1.j * np.random.randn(*dense_shape_b))).astype(dtype)
      a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
      b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats)
      c_dense = math_ops.matmul(
          a_mats,
          b_mats,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
      c_sm = sparse_csr_matrix_ops.matmul(
          a_sm,
          b_sm,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
      self.assertIsInstance(c_sm, sparse_csr_matrix_ops.CSRSparseMatrix)
      c_sm_dense = c_sm.to_dense()
      c_dense, c_sm_dense = self.evaluate([c_dense, c_sm_dense])
      self.assertAllClose(c_dense, c_sm_dense)

  @test_util.run_in_graph_and_eager_modes
  def testSparseSparse(self):
    for (t_a, t_b, adj_a, adj_b) in itertools.product(*(([False, True],) * 4)):
      if (t_a and adj_a) or (t_b and adj_b):
        continue
      self._testSparseSparse(t_a, t_b, adj_a, adj_b)

  def _testSparseDense(self, transpose_a, transpose_b, adjoint_a, adjoint_b):
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (m > 0)
    dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
    dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
    for dtype in np.float32, np.complex64:
      a_mats = sparsify((np.random.randn(*dense_shape_a) +
                         1.j * np.random.randn(*dense_shape_a))).astype(dtype)
      b_mats = (np.random.randn(*dense_shape_b) +
                1.j * np.random.randn(*dense_shape_b)).astype(dtype)
      a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_mats)
      c_dense = math_ops.matmul(
          a_mats,
          b_mats,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
      c_sm_dense = sparse_csr_matrix_ops.matmul(
          a_sm,
          b_mats,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
      c_dense, c_sm_dense = self.evaluate([c_dense, c_sm_dense])
      self.assertAllClose(c_dense, c_sm_dense)

  @test_util.run_in_graph_and_eager_modes
  def testSparseDense(self):
    for (t_a, t_b, adj_a, adj_b) in itertools.product(*(([False, True],) * 4)):
      if (t_a and adj_a) or (t_b and adj_b):
        continue
      self._testSparseDense(t_a, t_b, adj_a, adj_b)

  def _testDenseSparse(self, transpose_a, transpose_b, adjoint_a, adjoint_b):
    if not self._gpu_available:
      return

    sparsify = lambda m: m * (m > 0)
    dense_shape_a = [5, 13, 7] if transpose_a or adjoint_a else [5, 7, 13]
    dense_shape_b = [5, 15, 13] if transpose_b or adjoint_b else [5, 13, 15]
    for dtype in np.float32, np.complex64:
      a_mats = (np.random.randn(*dense_shape_a) +
                1.j * np.random.randn(*dense_shape_a)).astype(dtype)
      b_mats = sparsify((np.random.randn(*dense_shape_b) +
                         1.j * np.random.randn(*dense_shape_b))).astype(dtype)
      b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_mats)
      c_dense = math_ops.matmul(
          a_mats,
          b_mats,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
      c_sm_dense = sparse_csr_matrix_ops.matmul(
          a_mats,
          b_sm,
          transpose_a=transpose_a,
          transpose_b=transpose_b,
          adjoint_a=adjoint_a,
          adjoint_b=adjoint_b)
      c_dense, c_sm_dense = self.evaluate([c_dense, c_sm_dense])
      self.assertAllClose(c_dense, c_sm_dense)

  @test_util.run_in_graph_and_eager_modes
  def testDenseSparse(self):
    for (t_a, t_b, adj_a, adj_b) in itertools.product(*(([False, True],) * 4)):
      if (t_a and adj_a) or (t_b and adj_b):
        continue
      self._testDenseSparse(t_a, t_b, adj_a, adj_b)


if __name__ == "__main__":
  test.main()

@ -0,0 +1,138 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""CSR sparse matrix tests."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gradient_checker
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_grad # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
|
def dense_to_csr_sparse_matrix(dense):
|
||||||
|
dense_t = ops.convert_to_tensor(dense)
|
||||||
|
locs = array_ops.stop_gradient(array_ops.where(math_ops.abs(dense_t) > 0))
|
||||||
|
return sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(dense_t, locs)
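The helper above hands the conversion op both the dense values and the int64 coordinates of their nonzeros. A tiny worked example of what it feeds through (values invented for illustration):

# dense = [[1., 0.],
#          [0., 2.]]
# array_ops.where(math_ops.abs(dense) > 0)  ->  [[0, 0], [1, 1]]   (int64 row/col pairs)
# dense_to_csr_sparse_matrix(dense, locs)   ->  a DT_VARIANT tensor holding the CSR matrix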
|
||||||
|
|
||||||
|
|
||||||
|
def _add_test(test, op_name, testcase_name, fn): # pylint: disable=redefined-outer-name
|
||||||
|
if fn is None:
|
||||||
|
return
|
||||||
|
test_name = "_".join(["test", op_name, testcase_name])
|
||||||
|
if hasattr(test, test_name):
|
||||||
|
raise RuntimeError("Test %s defined more than once" % test_name)
|
||||||
|
setattr(test, test_name, fn)
|
||||||
|
|
||||||
|
|
||||||
|
class CSRSparseMatrixDenseMatMulGradTest(test.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super(CSRSparseMatrixDenseMatMulGradTest, cls).setUpClass()
|
||||||
|
cls._gpu_available = test_util.is_gpu_available()
|
||||||
|
|
||||||
|
  # TODO(penporn): Make these tests runnable in eager mode.
|
||||||
|
# (tf.gradients and gradient_checker only run in graph mode.)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def _testLargeBatchSparseMatrixMatMulGrad(self, datatype, transpose_a,
|
||||||
|
transpose_b, adjoint_a, adjoint_b,
|
||||||
|
transpose_output, conjugate_output):
|
||||||
|
if not self._gpu_available:
|
||||||
|
return
|
||||||
|
|
||||||
|
sparsify = lambda m: m * (m > 0)
|
||||||
|
a_mats_val = sparsify(
|
||||||
|
np.random.randn(3, 5, 11) +
|
||||||
|
1.j * np.random.randn(3, 5, 11)).astype(datatype)
|
||||||
|
if transpose_a or adjoint_a:
|
||||||
|
a_mats_val = np.transpose(a_mats_val, (0, 2, 1))
|
||||||
|
if adjoint_a:
|
||||||
|
a_mats_val = np.conj(a_mats_val)
|
||||||
|
b_mats_val = (np.random.randn(3, 11, 13) +
|
||||||
|
1.j * np.random.randn(3, 11, 13)).astype(datatype)
|
||||||
|
if transpose_b or adjoint_b:
|
||||||
|
b_mats_val = np.transpose(b_mats_val, (0, 2, 1))
|
||||||
|
if adjoint_b:
|
||||||
|
b_mats_val = np.conj(b_mats_val)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
a_mats = ops.convert_to_tensor(a_mats_val, dtype=datatype)
|
||||||
|
b_mats = ops.convert_to_tensor(b_mats_val, dtype=datatype)
|
||||||
|
a_sm = dense_to_csr_sparse_matrix(a_mats)
|
||||||
|
c_mats = sparse_csr_matrix_ops.sparse_matrix_mat_mul(
|
||||||
|
a_sm,
|
||||||
|
b_mats,
|
||||||
|
transpose_a=transpose_a,
|
||||||
|
transpose_b=transpose_b,
|
||||||
|
adjoint_a=adjoint_a,
|
||||||
|
adjoint_b=adjoint_b,
|
||||||
|
transpose_output=transpose_output,
|
||||||
|
conjugate_output=conjugate_output)
|
||||||
|
for [ten, val, nn] in [[a_mats, a_mats_val, "a"],
|
||||||
|
[b_mats, b_mats_val, "b"]]:
|
||||||
|
tf_logging.info("Testing gradients for %s" % nn)
|
||||||
|
theoretical, numerical = gradient_checker.compute_gradient(
|
||||||
|
ten,
|
||||||
|
ten.get_shape().as_list(),
|
||||||
|
c_mats,
|
||||||
|
c_mats.get_shape().as_list(),
|
||||||
|
x_init_value=val,
|
||||||
|
delta=1e-3)
|
||||||
|
self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
|
||||||
|
# "medium".
|
||||||
|
for dtype in (np.float32, np.complex64):
|
||||||
|
for (t_a, t_b, adj_a, adj_b, t_out,
|
||||||
|
conj_out) in itertools.product(*(([False, True],) * 6)):
|
||||||
|
|
||||||
|
def create_mat_mul_test_fn(dtype_, t_a_, t_b_, adj_a_, adj_b_, t_out_,
|
||||||
|
conj_out_):
|
||||||
|
# Skip invalid cases.
|
||||||
|
if (t_a_ and adj_a_) or (t_b_ and adj_b_):
|
||||||
|
return
|
||||||
|
# Skip cases where we conjugate real matrices.
|
||||||
|
if dtype_ == np.float32 and (adj_a_ or adj_b_ or conj_out_):
|
||||||
|
return
|
||||||
|
|
||||||
|
def test_fn(self):
|
||||||
|
self._testLargeBatchSparseMatrixMatMulGrad(dtype_, t_a_, t_b_, adj_a_,
|
||||||
|
adj_b_, t_out_, conj_out_)
|
||||||
|
|
||||||
|
return test_fn
|
||||||
|
|
||||||
|
name = (
|
||||||
|
"_testLargeBatchSparseMatrixMatMulGrad_dtype_%s_t_a_%s_t_b_%s_adj_a_%s_"
|
||||||
|
"adj_b_%s_t_out_%s_conj_out_%s" %
|
||||||
|
(dtype.__name__, t_a, t_b, adj_a, adj_b, t_out, conj_out))
|
||||||
|
|
||||||
|
_add_test(
|
||||||
|
CSRSparseMatrixDenseMatMulGradTest, "CSRSparseMatrixGradTest", name,
|
||||||
|
create_mat_mul_test_fn(dtype, t_a, t_b, adj_a, adj_b, t_out, conj_out))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
119 tensorflow/python/kernel_tests/sparse_csr_matrix_grad_test.py (new file)
@ -0,0 +1,119 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""CSR sparse matrix tests."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_grad # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
|
def dense_to_csr_sparse_matrix(dense):
|
||||||
|
dense_t = ops.convert_to_tensor(dense)
|
||||||
|
locs = array_ops.stop_gradient(array_ops.where(math_ops.abs(dense_t) > 0))
|
||||||
|
return sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(dense_t, locs)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_test(test, op_name, testcase_name, fn): # pylint: disable=redefined-outer-name
|
||||||
|
if fn is None:
|
||||||
|
return
|
||||||
|
test_name = "_".join(["test", op_name, testcase_name])
|
||||||
|
if hasattr(test, test_name):
|
||||||
|
raise RuntimeError("Test %s defined more than once" % test_name)
|
||||||
|
setattr(test, test_name, fn)
|
||||||
|
|
||||||
|
|
||||||
|
class CSRSparseMatrixGradTest(test.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super(CSRSparseMatrixGradTest, cls).setUpClass()
|
||||||
|
cls._gpu_available = test_util.is_gpu_available()
|
||||||
|
|
||||||
|
  # TODO(penporn): Make these tests runnable in eager mode.
|
||||||
|
# (tf.gradients and gradient_checker only run in graph mode.)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testLargeBatchConversionGrad(self):
|
||||||
|
if not self._gpu_available:
|
||||||
|
return
|
||||||
|
|
||||||
|
sparsify = lambda m: m * (m > 0)
|
||||||
|
for dense_shape in ([53, 65, 127], [127, 65]):
|
||||||
|
mats_val = sparsify(np.random.randn(*dense_shape))
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
mats = math_ops.cast(mats_val, dtype=dtypes.float32)
|
||||||
|
sparse_mats = dense_to_csr_sparse_matrix(mats)
|
||||||
|
dense_mats = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
|
||||||
|
sparse_mats, dtypes.float32)
|
||||||
|
grad_vals = np.random.randn(*dense_shape).astype(np.float32)
|
||||||
|
grad_out = gradients_impl.gradients([dense_mats], [mats],
|
||||||
|
[grad_vals])[0]
|
||||||
|
self.assertEqual(grad_out.dtype, dtypes.float32)
|
||||||
|
self.assertEqual(grad_out.shape, dense_shape)
|
||||||
|
grad_out_value = sess.run(grad_out)
|
||||||
|
tf_logging.info("testLargeBatchConversionGrad: Testing shape %s" %
|
||||||
|
dense_shape)
|
||||||
|
self.assertAllEqual(grad_vals, grad_out_value)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testLargeBatchSparseMatrixAddGrad(self):
|
||||||
|
if not self._gpu_available:
|
||||||
|
return
|
||||||
|
|
||||||
|
sparsify = lambda m: m * (m > 0)
|
||||||
|
for dense_shape in ([53, 65, 127], [127, 65]):
|
||||||
|
a_mats_val = sparsify(np.random.randn(*dense_shape))
|
||||||
|
b_mats_val = sparsify(np.random.randn(*dense_shape))
|
||||||
|
alpha = np.float32(0.5)
|
||||||
|
beta = np.float32(-1.5)
|
||||||
|
grad_vals = np.random.randn(*dense_shape).astype(np.float32)
|
||||||
|
expected_a_grad = alpha * grad_vals
|
||||||
|
expected_b_grad = beta * grad_vals
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
a_mats = math_ops.cast(a_mats_val, dtype=dtypes.float32)
|
||||||
|
b_mats = math_ops.cast(b_mats_val, dtype=dtypes.float32)
|
||||||
|
a_sm = dense_to_csr_sparse_matrix(a_mats)
|
||||||
|
b_sm = dense_to_csr_sparse_matrix(b_mats)
|
||||||
|
c_sm = sparse_csr_matrix_ops.sparse_matrix_add(
|
||||||
|
a_sm, b_sm, alpha=alpha, beta=beta)
|
||||||
|
c_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
|
||||||
|
c_sm, dtypes.float32)
|
||||||
|
a_grad, b_grad = gradients_impl.gradients([c_dense], [a_mats, b_mats],
|
||||||
|
[grad_vals])
|
||||||
|
self.assertEqual(a_grad.dtype, dtypes.float32)
|
||||||
|
self.assertEqual(b_grad.dtype, dtypes.float32)
|
||||||
|
self.assertEqual(a_grad.shape, dense_shape)
|
||||||
|
self.assertEqual(b_grad.shape, dense_shape)
|
||||||
|
a_grad_value, b_grad_value = sess.run((a_grad, b_grad))
|
||||||
|
tf_logging.info("testLargeBatchConversionGrad: Testing shape %s" %
|
||||||
|
dense_shape)
|
||||||
|
self.assertAllEqual(expected_a_grad, a_grad_value)
|
||||||
|
self.assertAllEqual(expected_b_grad, b_grad_value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
1511 tensorflow/python/kernel_tests/sparse_csr_matrix_ops_test.py (new file; diff not shown because it is too large)
@ -0,0 +1,137 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""CSR sparse matrix tests."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gradient_checker
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_grad # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
|
def dense_to_csr_sparse_matrix(dense):
|
||||||
|
dense_t = ops.convert_to_tensor(dense)
|
||||||
|
locs = array_ops.stop_gradient(array_ops.where(math_ops.abs(dense_t) > 0))
|
||||||
|
return sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(dense_t, locs)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_test(test, op_name, testcase_name, fn): # pylint: disable=redefined-outer-name
|
||||||
|
if fn is None:
|
||||||
|
return
|
||||||
|
test_name = "_".join(["test", op_name, testcase_name])
|
||||||
|
if hasattr(test, test_name):
|
||||||
|
raise RuntimeError("Test %s defined more than once" % test_name)
|
||||||
|
setattr(test, test_name, fn)
|
||||||
|
|
||||||
|
|
||||||
|
class CSRSparseMatrixGradTest(test.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
super(CSRSparseMatrixGradTest, cls).setUpClass()
|
||||||
|
cls._gpu_available = test_util.is_gpu_available()
|
||||||
|
|
||||||
|
  # TODO(penporn): Make these tests runnable in eager mode.
|
||||||
|
# (tf.gradients and gradient_checker only run in graph mode.)
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def _testLargeBatchSparseMatrixSparseMatMulGrad(self, datatype, transpose_a,
|
||||||
|
transpose_b, adjoint_a,
|
||||||
|
adjoint_b):
|
||||||
|
if not self._gpu_available:
|
||||||
|
return
|
||||||
|
|
||||||
|
sparsify = lambda m: m * (m > 0)
|
||||||
|
a_mats_val = sparsify(
|
||||||
|
np.random.randn(3, 5, 11) +
|
||||||
|
1.j * np.random.randn(3, 5, 11)).astype(datatype)
|
||||||
|
if transpose_a or adjoint_a:
|
||||||
|
a_mats_val = np.transpose(a_mats_val, (0, 2, 1))
|
||||||
|
if adjoint_a:
|
||||||
|
a_mats_val = np.conj(a_mats_val)
|
||||||
|
b_mats_val = sparsify(
|
||||||
|
np.random.randn(3, 11, 13) +
|
||||||
|
1.j * np.random.randn(3, 11, 13)).astype(datatype)
|
||||||
|
if transpose_b or adjoint_b:
|
||||||
|
b_mats_val = np.transpose(b_mats_val, (0, 2, 1))
|
||||||
|
if adjoint_b:
|
||||||
|
b_mats_val = np.conj(b_mats_val)
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
a_mats = ops.convert_to_tensor(a_mats_val, dtype=datatype)
|
||||||
|
b_mats = ops.convert_to_tensor(b_mats_val, dtype=datatype)
|
||||||
|
a_sm = dense_to_csr_sparse_matrix(a_mats)
|
||||||
|
b_sm = dense_to_csr_sparse_matrix(b_mats)
|
||||||
|
c_sm = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(
|
||||||
|
a_sm,
|
||||||
|
b_sm,
|
||||||
|
transpose_a=transpose_a,
|
||||||
|
transpose_b=transpose_b,
|
||||||
|
adjoint_a=adjoint_a,
|
||||||
|
adjoint_b=adjoint_b,
|
||||||
|
type=datatype)
|
||||||
|
c_dense = sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
|
||||||
|
c_sm, type=datatype)
|
||||||
|
for ten, val, nn in [[a_mats, a_mats_val, "a"], [b_mats, b_mats_val,
|
||||||
|
"b"]]:
|
||||||
|
tf_logging.info("Testing gradients for %s" % nn)
|
||||||
|
theoretical, numerical = gradient_checker.compute_gradient(
|
||||||
|
ten,
|
||||||
|
ten.get_shape().as_list(),
|
||||||
|
c_dense,
|
||||||
|
c_dense.get_shape().as_list(),
|
||||||
|
x_init_value=val,
|
||||||
|
delta=1e-3)
|
||||||
|
self.assertAllClose(theoretical, numerical, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
# These tests are refactored from sparse_csr_matrix_grad_test to keep its size
|
||||||
|
# "medium".
|
||||||
|
for dtype in (np.float32, np.complex64):
|
||||||
|
for (t_a, t_b, adj_a, adj_b) in itertools.product(*(([False, True],) * 4)):
|
||||||
|
|
||||||
|
def create_sparse_mat_mul_test_fn(dtype_, t_a_, t_b_, adj_a_, adj_b_):
|
||||||
|
# Skip invalid cases.
|
||||||
|
if (t_a_ and adj_a_) or (t_b_ and adj_b_):
|
||||||
|
return
|
||||||
|
# Skip cases where we conjugate real matrices.
|
||||||
|
if dtype_ == np.float32 and (adj_a_ or adj_b_):
|
||||||
|
return
|
||||||
|
|
||||||
|
def test_fn(self):
|
||||||
|
self._testLargeBatchSparseMatrixSparseMatMulGrad(
|
||||||
|
dtype_, t_a_, t_b_, adj_a_, adj_b_)
|
||||||
|
|
||||||
|
return test_fn
|
||||||
|
|
||||||
|
name = (
|
||||||
|
"_testLargeBatchSparseMatrixSparseMatMulGrad_dtype_%s_t_a_%s_t_b_%s_"
|
||||||
|
"adj_a_%s_adj_b_%s" % (dtype.__name__, t_a, t_b, adj_a, adj_b))
|
||||||
|
|
||||||
|
_add_test(CSRSparseMatrixGradTest, "CSRSparseMatrixSparseGradTest", name,
|
||||||
|
create_sparse_mat_mul_test_fn(dtype, t_a, t_b, adj_a, adj_b))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
35 tensorflow/python/ops/linalg/sparse/BUILD (new file)
@ -0,0 +1,35 @@
|
|||||||
|
# Description: Sparse CSR support for TensorFlow.
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//tensorflow:internal"],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_gen_op_wrapper_py(
|
||||||
|
name = "gen_sparse_csr_matrix_ops",
|
||||||
|
out = "gen_sparse_csr_matrix_ops.py",
|
||||||
|
api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
|
||||||
|
visibility = [
|
||||||
|
"//learning/brain/python/ops:__pkg__",
|
||||||
|
"//tensorflow/compiler/tests:__pkg__",
|
||||||
|
"//tensorflow/contrib/quantization:__pkg__",
|
||||||
|
"//tensorflow/python/kernel_tests:__pkg__",
|
||||||
|
],
|
||||||
|
deps = ["//tensorflow/core:sparse_csr_matrix_ops_op_lib"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "sparse",
|
||||||
|
srcs = [
|
||||||
|
"__init__.py",
|
||||||
|
"sparse.py",
|
||||||
|
"sparse_csr_matrix_grad.py",
|
||||||
|
"sparse_csr_matrix_ops.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":gen_sparse_csr_matrix_ops",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
0 tensorflow/python/ops/linalg/sparse/__init__.py (new, empty file)
25 tensorflow/python/ops/linalg/sparse/sparse.py (new file)
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Public API for tf.linalg.sparse namespace."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# go/tf-wildcard-import
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.python.ops.linalg.sparse.sparse_csr_matrix_grad import *
|
||||||
|
from tensorflow.python.ops.linalg.sparse.sparse_csr_matrix_ops import *
|
||||||
|
# pylint: enable=wildcard-import
|
233 tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_grad.py (new file)
@ -0,0 +1,233 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""CSR Sparse Matrix Gradients."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("DenseToCSRSparseMatrix")
|
||||||
|
def _DenseToCSRSparseMatrixGrad(op, grad):
|
||||||
|
"""Gradient for dense_to_csr_sparse_matrix op."""
|
||||||
|
grad_values = (
|
||||||
|
sparse_csr_matrix_ops.csr_sparse_matrix_to_dense(
|
||||||
|
grad, type=op.get_attr("T")))
|
||||||
|
# inputs to fw op were: params, indices.
|
||||||
|
return (grad_values, None)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("CSRSparseMatrixToDense")
|
||||||
|
def _CSRSparseMatrixToDenseGrad(op, grad):
|
||||||
|
"""Gradient for csr_sparse_matrix_to_dense op."""
|
||||||
|
del op # Unused
|
||||||
|
return sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
|
||||||
|
grad, array_ops.stop_gradient(array_ops.where(math_ops.abs(grad) > 0)))
|
||||||
|
|
||||||
|
|
||||||
|
ops.NotDifferentiable("SparseMatrixNNZ")
|
||||||
|
|
||||||
|
ops.NotDifferentiable("SparseMatrixZeros")
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("SparseMatrixAdd")
|
||||||
|
def _SparseMatrixAddGrad(op, grad):
|
||||||
|
"""Gradient for sparse_matrix_add op."""
|
||||||
|
# input to sparse_matrix_add is (a, b, alpha, beta)
|
||||||
|
# with a, b CSR and alpha beta scalars.
|
||||||
|
# output is: alpha * a + beta * b
|
||||||
|
|
||||||
|
# d(a*A + b*B)/dA . grad = a * grad
|
||||||
|
|
||||||
|
# May have gotten the transposes wrong below.
|
||||||
|
# d(a*A + b*B)/da . grad = tr(A' . grad)
|
||||||
|
|
||||||
|
# For now, only implement gradients w.r.t. A and B.
|
||||||
|
# TODO(ebrevdo): Implement reduce_sum for SparseMatrix so that we
|
||||||
|
# can implement gradients w.r.t. a and b.
|
||||||
|
(_, _, alpha, beta) = op.inputs
|
||||||
|
return (sparse_csr_matrix_ops.sparse_matrix_mul(grad, alpha),
|
||||||
|
sparse_csr_matrix_ops.sparse_matrix_mul(grad, beta), None, None)
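Written out once, the identities the comments above compress (a reference sketch; G denotes the incoming gradient with respect to the output):

\[
C = \alpha A + \beta B
\quad\Longrightarrow\quad
\frac{\partial L}{\partial A} = \alpha\,G, \qquad
\frac{\partial L}{\partial B} = \beta\,G, \qquad
\frac{\partial L}{\partial \alpha} = \operatorname{tr}(A^{\top} G).
\]

Only the first two are returned; the scalar gradients are the unimplemented part noted in the TODO, hence the trailing `None`s.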
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("SparseMatrixTranspose")
|
||||||
|
def _SparseMatrixTransposeGrad(op, grad):
|
||||||
|
"""Gradient for sparse_matrix_transpose op."""
|
||||||
|
return sparse_csr_matrix_ops.sparse_matrix_transpose(
|
||||||
|
grad, type=op.get_attr("type"), conjugate=op.get_attr("conjugate"))
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("SparseMatrixSoftmax")
|
||||||
|
def _SparseMatrixSoftmaxGrad(op, grad_softmax):
|
||||||
|
"""Gradient for sparse_matrix_softmax op."""
|
||||||
|
softmax = op.outputs[0]
|
||||||
|
return sparse_csr_matrix_ops.sparse_matrix_softmax_grad(
|
||||||
|
softmax, grad_softmax, type=op.get_attr("type"))
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("SparseMatrixMatMul")
|
||||||
|
def _SparseMatrixMatMulGrad(op, grad):
|
||||||
|
"""Gradient for sparse_matrix_mat_mul op."""
|
||||||
|
# input to sparse_matrix_mat_mul is (A, B) with CSR A and dense B.
|
||||||
|
# Output is dense:
|
||||||
|
# C = opA(A) . opB(B) if transpose_output = false
|
||||||
|
# C = (opA(A) . opB(B))' = opB(B)' . opA(A)' if transpose_output = true.
|
||||||
|
# where opA = transpose if transpose_a = True else identity
|
||||||
|
# and opB = transpose if transpose_b = True else identity
|
||||||
|
|
||||||
|
t_a = op.get_attr("transpose_a")
|
||||||
|
t_b = op.get_attr("transpose_b")
|
||||||
|
adj_a = op.get_attr("adjoint_a")
|
||||||
|
adj_b = op.get_attr("adjoint_b")
|
||||||
|
transpose_output = op.get_attr("transpose_output")
|
||||||
|
conjugate_output = op.get_attr("conjugate_output")
|
||||||
|
a = op.inputs[0] # sparse matrix
|
||||||
|
b = op.inputs[1] # dense matrix
|
||||||
|
conj = math_ops.conj
|
||||||
|
sparse_matmul = sparse_csr_matrix_ops.sparse_matrix_mat_mul
|
||||||
|
matmul = math_ops.matmul
|
||||||
|
|
||||||
|
if conjugate_output:
|
||||||
|
grad = conj(grad)
|
||||||
|
if not transpose_output:
|
||||||
|
# C = opA(A) . opB(B)
|
||||||
|
if not adj_a and not adj_b:
|
||||||
|
a = conj(a)
|
||||||
|
b = conj(b)
|
||||||
|
if not t_a:
|
||||||
|
grad_a_dense = matmul(grad, b, transpose_b=not t_b)
|
||||||
|
else:
|
||||||
|
grad_a_dense = matmul(b, grad, transpose_a=t_b, transpose_b=True)
|
||||||
|
grad_b = sparse_matmul(a, grad, transpose_a=not t_a, transpose_output=t_b)
|
||||||
|
elif not t_a and not t_b:
|
||||||
|
if not adj_a:
|
||||||
|
grad_a_dense = matmul(grad, b, adjoint_b=not adj_b)
|
||||||
|
else:
|
||||||
|
grad_a_dense = matmul(b, grad, adjoint_a=adj_b, adjoint_b=True)
|
||||||
|
grad_b = sparse_matmul(
|
||||||
|
a,
|
||||||
|
grad,
|
||||||
|
adjoint_a=not adj_a,
|
||||||
|
transpose_output=adj_b,
|
||||||
|
conjugate_output=adj_b)
|
||||||
|
elif adj_a and t_b:
|
||||||
|
grad_a_dense = matmul(b, grad, transpose_a=True, adjoint_b=True)
|
||||||
|
grad_b = sparse_matmul(a, grad, transpose_output=True)
|
||||||
|
elif t_a and adj_b:
|
||||||
|
grad_a_dense = matmul(b, grad, transpose_a=True, transpose_b=True)
|
||||||
|
grad_b = sparse_matmul(
|
||||||
|
conj(a), grad, transpose_output=True, conjugate_output=True)
|
||||||
|
else:
|
||||||
|
# C = (opA(A) . opB(B))' = opB(B)' . opA(A)'
|
||||||
|
if not adj_a and not adj_b:
|
||||||
|
a = conj(a)
|
||||||
|
b = conj(b)
|
||||||
|
if not t_a:
|
||||||
|
grad_a_dense = matmul(grad, b, transpose_a=True, transpose_b=not t_b)
|
||||||
|
else:
|
||||||
|
grad_a_dense = matmul(b, grad, transpose_a=t_b)
|
||||||
|
grad_b = sparse_matmul(
|
||||||
|
a, grad, transpose_a=not t_a, transpose_b=True, transpose_output=t_b)
|
||||||
|
elif not t_a and not t_b:
|
||||||
|
if not adj_a:
|
||||||
|
grad_a_dense = matmul(grad, b, transpose_a=True, adjoint_b=not adj_b)
|
||||||
|
else:
|
||||||
|
grad_a_dense = matmul(b, conj(grad), adjoint_a=adj_b)
|
||||||
|
grad_b = sparse_matmul(
|
||||||
|
a,
|
||||||
|
grad,
|
||||||
|
adjoint_a=not adj_a,
|
||||||
|
transpose_b=True,
|
||||||
|
transpose_output=adj_b,
|
||||||
|
conjugate_output=adj_b)
|
||||||
|
elif adj_a and t_b:
|
||||||
|
grad_a_dense = matmul(b, conj(grad), transpose_a=True)
|
||||||
|
grad_b = sparse_matmul(a, grad, transpose_b=True, transpose_output=True)
|
||||||
|
elif t_a and adj_b:
|
||||||
|
grad_a_dense = matmul(b, grad, transpose_a=True)
|
||||||
|
grad_b = sparse_matmul(a, grad, adjoint_b=True, transpose_output=True)
|
||||||
|
|
||||||
|
grad_a = sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
|
||||||
|
grad_a_dense, array_ops.where(math_ops.abs(grad_a_dense) > 0))
|
||||||
|
return (grad_a, grad_b)
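For the simplest branch (real dtype, no transposes or adjoints, `transpose_output=False`), the case analysis above reduces to the standard matmul gradients; the other branches substitute the transposed or conjugated operands and output. As a reference sketch, with real-valued A, B and G the incoming gradient of the loss with respect to C:

\[
C = A B
\quad\Longrightarrow\quad
\frac{\partial L}{\partial A} = G\,B^{\top}, \qquad
\frac{\partial L}{\partial B} = A^{\top} G,
\]

with the `conj(...)` calls supplying the extra conjugations the complex and adjoint cases need.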
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("SparseMatrixSparseMatMul")
|
||||||
|
def _SparseMatrixSparseMatMulGrad(op, grad):
|
||||||
|
"""Gradient for sparse_matrix_sparse_mat_mul op."""
|
||||||
|
t_a = op.get_attr("transpose_a")
|
||||||
|
t_b = op.get_attr("transpose_b")
|
||||||
|
adj_a = op.get_attr("adjoint_a")
|
||||||
|
adj_b = op.get_attr("adjoint_b")
|
||||||
|
dtype = op.get_attr("type")
|
||||||
|
|
||||||
|
# input to sparse_matrix_sparse_mat_mul is (A, B) with CSR A and B.
|
||||||
|
# Output is CSR:
|
||||||
|
# C = opA(A) . opB(B)
|
||||||
|
# where opA = transpose if transpose_a = True else identity
|
||||||
|
# and opB = transpose if transpose_b = True else identity
|
||||||
|
a = op.inputs[0]
|
||||||
|
b = op.inputs[1]
|
||||||
|
conj = math_ops.conj
|
||||||
|
matmul = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul
|
||||||
|
if not t_a and not t_b:
|
||||||
|
if not adj_a:
|
||||||
|
if not adj_b:
|
||||||
|
grad_a = matmul(grad, b, adjoint_b=True, type=dtype)
|
||||||
|
grad_b = matmul(a, grad, adjoint_a=True, type=dtype)
|
||||||
|
else:
|
||||||
|
grad_a = matmul(grad, b, type=dtype)
|
||||||
|
grad_b = matmul(grad, a, adjoint_a=True, type=dtype)
|
||||||
|
else:
|
||||||
|
if not adj_b:
|
||||||
|
grad_a = matmul(b, grad, adjoint_b=True, type=dtype)
|
||||||
|
grad_b = matmul(a, grad, type=dtype)
|
||||||
|
else:
|
||||||
|
grad_a = matmul(b, grad, adjoint_a=True, adjoint_b=True, type=dtype)
|
||||||
|
grad_b = matmul(grad, a, adjoint_a=True, adjoint_b=True, type=dtype)
|
||||||
|
elif not adj_a and not adj_b:
|
||||||
|
if not t_a and t_b:
|
||||||
|
grad_a = matmul(grad, conj(b), type=dtype)
|
||||||
|
grad_b = matmul(grad, conj(a), transpose_a=True, type=dtype)
|
||||||
|
elif t_a and not t_b:
|
||||||
|
grad_a = matmul(conj(b), grad, transpose_b=True, type=dtype)
|
||||||
|
grad_b = matmul(conj(a), grad, type=dtype)
|
||||||
|
else:
|
||||||
|
grad_a = matmul(b, grad, adjoint_a=True, transpose_b=True, type=dtype)
|
||||||
|
grad_b = matmul(grad, a, transpose_a=True, adjoint_b=True, type=dtype)
|
||||||
|
elif adj_a and t_b:
|
||||||
|
grad_a = matmul(b, grad, transpose_a=True, adjoint_b=True, type=dtype)
|
||||||
|
grad_b = matmul(grad, a, transpose_a=True, transpose_b=True, type=dtype)
|
||||||
|
elif t_a and adj_b:
|
||||||
|
grad_a = matmul(b, grad, transpose_a=True, transpose_b=True, type=dtype)
|
||||||
|
grad_b = matmul(grad, a, adjoint_a=True, transpose_b=True, type=dtype)
|
||||||
|
|
||||||
|
return (grad_a, grad_b)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("SparseMatrixMul")
|
||||||
|
def _SparseMatrixMulGrad(op, grad):
|
||||||
|
"""Gradient for sparse_matrix_mul op."""
|
||||||
|
# input to sparse_matrix_mul is (A, B) with CSR A and dense B.
|
||||||
|
# Output is CSR:
|
||||||
|
# C = A .* B
|
||||||
|
del op
|
||||||
|
del grad
|
||||||
|
raise NotImplementedError
|
378 tensorflow/python/ops/linalg/sparse/sparse_csr_matrix_ops.py (new file)
@ -0,0 +1,378 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""CSR Sparse Matrix Operations."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
# pylint: disable=g-direct-tensorflow-import, wildcard-import
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.ops.linalg.sparse import gen_sparse_csr_matrix_ops as sm_ops
|
||||||
|
from tensorflow.python.ops.linalg.sparse.gen_sparse_csr_matrix_ops import *
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SparseMatrix",
|
||||||
|
"CSRSparseMatrix",
|
||||||
|
"matmul",
|
||||||
|
"dense_shape_and_type",
|
||||||
|
]
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
__all__ += [_x for _x in dir(sm_ops) if not _x.startswith("_")]
|
||||||
|
|
||||||
|
|
||||||
|
class DenseShapeAndType(
|
||||||
|
collections.namedtuple("DenseShapeAndType", ("shape", "dtype"))):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _get_handle_data(tensor):
|
||||||
|
return resource_variable_ops.get_eager_safe_handle_data(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_handle_data_proto(shape_proto, dtype_enum):
|
||||||
|
"""Create handle data based on shape and dtype protos."""
|
||||||
|
variant_shape_and_type_data = \
|
||||||
|
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
|
||||||
|
variant_shape_and_type_data.is_set = True
|
||||||
|
# NOTE(ebrevdo): shape_and_type lacks append() in some versions of protobuf.
|
||||||
|
variant_shape_and_type_data.shape_and_type.extend([
|
||||||
|
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
|
||||||
|
shape=shape_proto, dtype=dtype_enum)
|
||||||
|
])
|
||||||
|
return variant_shape_and_type_data
|
||||||
|
|
||||||
|
|
||||||
|
def _make_handle_data(tensor):
|
||||||
|
"""Create handle data based on tensor shape and dtype."""
|
||||||
|
return _create_handle_data_proto(tensor.shape.as_proto(),
|
||||||
|
tensor.dtype.as_datatype_enum)
|
||||||
|
|
||||||
|
|
||||||
|
def get_shape_and_type(matrix):
|
||||||
|
"""Return matrix's shape and type if available."""
|
||||||
|
handle_data = getattr(matrix, "_handle_data", None)
|
||||||
|
if handle_data is None:
|
||||||
|
return None
|
||||||
|
if len(handle_data.shape_and_type) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"shape_and_type array in _handle_data must have length one, but saw: %d"
|
||||||
|
% len(handle_data.shape_and_type))
|
||||||
|
return handle_data.shape_and_type[0]
|
||||||
|
|
||||||
|
|
||||||
|
def dense_shape_and_type(matrix):
|
||||||
|
"""Get dense shape and dtype of the tf.Tensor containing the matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
matrix: A `tf.Tensor` of type `tf.variant` storing a sparse matrix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of `ShapeAndType` with properties `shape` (a `tf.TensorShape`)
|
||||||
|
and `dtype` (a `tf.DType`).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `matrix` is not a tensor or its dtype is not variant.
|
||||||
|
ValueError: if `matrix` lacks static handle data containing the dense
|
||||||
|
shape and dtype.
|
||||||
|
"""
|
||||||
|
if not isinstance(matrix, ops.Tensor):
|
||||||
|
raise TypeError("matrix should be a tensor, but saw: %s" % (matrix,))
|
||||||
|
if matrix.dtype != dtypes.variant:
|
||||||
|
raise TypeError(
|
||||||
|
"expected matrix to be type tf.variant, but saw: %s" % (matrix.dtype,))
|
||||||
|
handle_data = _get_handle_data(matrix)
|
||||||
|
if not handle_data or not handle_data.is_set:
|
||||||
|
raise ValueError("matrix has missing handle data: %s" % (matrix,))
|
||||||
|
if len(handle_data.shape_and_type) != 1:
|
||||||
|
raise ValueError("len(matrix.handle_data.shape_and_type) != 1: '%s'" %
|
||||||
|
(handle_data.shape_and_type,))
|
||||||
|
return DenseShapeAndType(
|
||||||
|
tensor_shape.TensorShape(handle_data.shape_and_type[0].shape),
|
||||||
|
dtypes.DType(handle_data.shape_and_type[0].dtype))
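A small graph-mode usage sketch of this accessor (the literal values are assumptions; it relies on shape inference having populated the variant tensor's handle data, as described above):

import numpy as np  # only for clarity; a plain Python list works too
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

dense = ops.convert_to_tensor([[1.0, 0.0], [0.0, 2.0]])
csr = sparse_csr_matrix_ops.dense_to_csr_sparse_matrix(
    dense, array_ops.stop_gradient(array_ops.where(math_ops.abs(dense) > 0)))
info = sparse_csr_matrix_ops.dense_shape_and_type(csr)
# info.shape -> TensorShape([2, 2]); info.dtype -> tf.float32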
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_shape_inference(a, b, c, transpose_a, transpose_b, adjoint_a,
|
||||||
|
adjoint_b):
|
||||||
|
"""Helper function for matmul to set the result matrix's handle data."""
|
||||||
|
c_handle = getattr(c, "_handle_data", None)
|
||||||
|
a_shape_and_type = get_shape_and_type(a)
|
||||||
|
b_shape_and_type = get_shape_and_type(b)
|
||||||
|
if (c_handle is None and a_shape_and_type is not None and
|
||||||
|
b_shape_and_type is not None):
|
||||||
|
|
||||||
|
transpose_a = transpose_a or adjoint_a
|
||||||
|
transpose_b = transpose_b or adjoint_b
|
||||||
|
|
||||||
|
a_shape = a_shape_and_type.shape
|
||||||
|
b_shape = b_shape_and_type.shape
|
||||||
|
rank = len(a_shape.dim)
|
||||||
|
|
||||||
|
# Creates the output shape.
|
||||||
|
c_rows = a_shape.dim[rank - (1 if transpose_a else 2)].size
|
||||||
|
c_cols = b_shape.dim[rank - (2 if transpose_b else 1)].size
|
||||||
|
c_shape = tensor_shape.TensorShape(a_shape)
|
||||||
|
c_shape = tensor_shape.TensorShape(c_shape[:rank - 2] + [c_rows, c_cols])
|
||||||
|
c_handle = _create_handle_data_proto(c_shape.as_proto(),
|
||||||
|
a_shape_and_type.dtype)
|
||||||
|
return c_handle
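To make the index arithmetic concrete, a worked example under assumed shapes:

# a_shape = [5, 7, 13], b_shape = [5, 13, 15], rank = 3, no transposes/adjoints
# c_rows = a_shape.dim[rank - 2].size  ->  7
# c_cols = b_shape.dim[rank - 1].size  ->  15
# c_shape recorded in the handle data  ->  [5, 7, 15]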
|
||||||
|
|
||||||
|
|
||||||
|
def matmul(a,
|
||||||
|
b,
|
||||||
|
transpose_a=False,
|
||||||
|
transpose_b=False,
|
||||||
|
adjoint_a=False,
|
||||||
|
adjoint_b=False,
|
||||||
|
name=None):
|
||||||
|
"""Perform a sparse matrix matmul between `a` and `b`.
|
||||||
|
|
||||||
|
Performs a contraction between `a` and `b` along the two innermost dimensions.
|
||||||
|
If both `a` and `b` are instances of `SparseMatrix`, returns a new instance
|
||||||
|
of `SparseMatrix` (same type as `a`). If one is not an instance of
|
||||||
|
`SparseMatrix`, returns a dense `Tensor`:
|
||||||
|
|
||||||
|
```
|
||||||
|
c = opA(a) . opB(b)
|
||||||
|
```
|
||||||
|
where `opA` (resp. `opB`) is the transpose or hermitian transpose depending
|
||||||
|
on the values of `transpose_a` (resp. `transpose_b`) and `adjoint_a`
|
||||||
|
(resp. `adjoint_b`).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: `Tensor` or `SparseMatrix`, having rank `2` or `3`.
|
||||||
|
b: `Tensor` or `SparseMatrix`, having rank `2` or `3`.
|
||||||
|
transpose_a: Python `bool`.
|
||||||
|
transpose_b: Python `bool`.
|
||||||
|
adjoint_a: Python `bool`.
|
||||||
|
adjoint_b: Python `bool`.
|
||||||
|
name: Optional name to use when creating ops.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `SparseMatrix` if both `a` and `b` are instances of `SparseMatrix`,
|
||||||
|
otherwise a dense `Tensor`.
|
||||||
|
"""
|
||||||
|
if not isinstance(a, SparseMatrix) and not isinstance(b, SparseMatrix):
|
||||||
|
return math_ops.matmul(
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
transpose_a=transpose_a,
|
||||||
|
transpose_b=transpose_b,
|
||||||
|
adjoint_a=adjoint_a,
|
||||||
|
adjoint_b=adjoint_b,
|
||||||
|
name=name)
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
a_matrix = a._matrix if isinstance(a, SparseMatrix) else a
|
||||||
|
b_matrix = b._matrix if isinstance(b, SparseMatrix) else b
|
||||||
|
with ops.name_scope(name, "SparseMatrixMatMul", [a_matrix, b_matrix]):
|
||||||
|
if isinstance(a, SparseMatrix) and isinstance(b, SparseMatrix):
|
||||||
|
if not (isinstance(a, type(b)) or isinstance(b, type(a))):
|
||||||
|
raise TypeError("SparseMatrix types don't inherit from each other: "
|
||||||
|
"%s and %s" % (type(a), type(b)))
|
||||||
|
c = sm_ops.sparse_matrix_sparse_mat_mul(
|
||||||
|
a_matrix,
|
||||||
|
b_matrix,
|
||||||
|
transpose_a=transpose_a,
|
||||||
|
transpose_b=transpose_b,
|
||||||
|
adjoint_a=adjoint_a,
|
||||||
|
adjoint_b=adjoint_b,
|
||||||
|
type=a.dtype)
|
||||||
|
|
||||||
|
# In eager mode, shape inference functions are not called, and the output
|
||||||
|
# shape is not set. We have to infer the output shape here.
|
||||||
|
# TODO(penporn): Set this from the C++ kernel instead.
|
||||||
|
c_handle = matmul_shape_inference(a_matrix, b_matrix, c, transpose_a,
|
||||||
|
transpose_b, adjoint_a, adjoint_b)
|
||||||
|
return a._from_matrix(c, handle_data=c_handle)
|
||||||
|
|
||||||
|
elif isinstance(a, SparseMatrix):
|
||||||
|
return sm_ops.sparse_matrix_mat_mul(
|
||||||
|
a_matrix,
|
||||||
|
b,
|
||||||
|
transpose_a=transpose_a,
|
||||||
|
transpose_b=transpose_b,
|
||||||
|
adjoint_a=adjoint_a,
|
||||||
|
adjoint_b=adjoint_b)
|
||||||
|
else:
|
||||||
|
# opA(A) . opB(B) = t(nopB(B) . nopA(A))
|
||||||
|
if not adjoint_a and not adjoint_b:
|
||||||
|
return sm_ops.sparse_matrix_mat_mul(
|
||||||
|
b_matrix,
|
||||||
|
a,
|
||||||
|
transpose_a=not transpose_b,
|
||||||
|
transpose_b=not transpose_a,
|
||||||
|
transpose_output=True)
|
||||||
|
elif not transpose_a and not transpose_b:
|
||||||
|
return sm_ops.sparse_matrix_mat_mul(
|
||||||
|
b_matrix,
|
||||||
|
a,
|
||||||
|
adjoint_a=not adjoint_b,
|
||||||
|
adjoint_b=not adjoint_a,
|
||||||
|
transpose_output=True,
|
||||||
|
conjugate_output=True)
|
||||||
|
else:
|
||||||
|
return sm_ops.sparse_matrix_mat_mul(
|
||||||
|
b_matrix,
|
||||||
|
math_ops.conj(a),
|
||||||
|
transpose_output=True,
|
||||||
|
conjugate_output=adjoint_b)
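A minimal usage sketch of this dispatch (shapes and values invented; like the tests above, it assumes a build where the sparse GPU kernels are available, and `CSRSparseMatrix` is the concrete class defined later in this module):

import numpy as np
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

a_dense = np.random.randn(5, 7, 13).astype(np.float32)
a_dense = a_dense * (a_dense > 0)  # sparsify, as the tests do
b_dense = np.random.randn(5, 13, 15).astype(np.float32)

a_sm = sparse_csr_matrix_ops.CSRSparseMatrix(a_dense)
b_sm = sparse_csr_matrix_ops.CSRSparseMatrix(b_dense)
c_dense = sparse_csr_matrix_ops.matmul(a_sm, b_dense)  # dense Tensor, shape [5, 7, 15]
c_sparse = sparse_csr_matrix_ops.matmul(a_sm, b_sm)    # CSRSparseMatrix; inspect via .to_dense()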
|
||||||
|
|
||||||
|
|
||||||
|
class SparseMatrix(six.with_metaclass(abc.ABCMeta)):
|
||||||
|
"""Abstract class for sparse matrix types."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(self):
|
||||||
|
self._eager_mode = context.executing_eagerly()
|
||||||
|
|
||||||
|
@abc.abstractproperty
|
||||||
|
def _matrix(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _from_matrix(self, matrix, handle_data=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def to_dense(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def to_sparse_tensor(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def graph(self):
|
||||||
|
return self._matrix.graph
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return dense_shape_and_type(self._matrix).shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return dense_shape_and_type(self._matrix).dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eager_handle_data(self):
|
||||||
|
"""Return the matrix's handle data iff in eager mode."""
|
||||||
|
return _get_handle_data(self._matrix) if self._eager_mode else None
|
||||||
|
|
||||||
|
def conj(self):
|
||||||
|
return self._from_matrix(
|
||||||
|
math_ops.conj(self._matrix), self.eager_handle_data)
|
||||||
|
|
||||||
|
def hermitian_transpose(self):
|
||||||
|
"""Return the hermitian transpose of the matrix."""
|
||||||
|
return self._from_matrix(
|
||||||
|
sm_ops.sparse_matrix_transpose(
|
||||||
|
self._matrix, conjugate=True, type=self.dtype),
|
||||||
|
self.eager_handle_data)
|
||||||
|
|
||||||
|
def nnz(self):
|
||||||
|
"""Number of stored values, including explicit zeros."""
|
||||||
|
return sm_ops.sparse_matrix_nnz(self._matrix)
|
||||||
|
|
||||||
|
nonzero = nnz
|
||||||
|
|
||||||
|
def sorted_indices(self):
|
||||||
|
# TODO(ebrevdo): A more efficient implementation?
|
||||||
|
return self.to_sparse_tensor().indices
|
||||||
|
|
||||||
|
def transpose(self):
|
||||||
|
return self._from_matrix(
|
||||||
|
sm_ops.sparse_matrix_transpose(self._matrix, type=self.dtype),
|
||||||
|
self.eager_handle_data)
|
||||||
|
|
||||||
|
|
||||||
|
class CSRSparseMatrix(SparseMatrix):
|
||||||
|
"""(Optionally batched) CSR Sparse Matrix."""
|
||||||
|
|
||||||
|
def __init__(self, value, indices=None, name=None):
|
||||||
|
"""Construct a CSRSparseMatrix from a dense matrix or SparseTensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: A dense `2D` or `3D` Tensor or `SparseTensor`.
|
||||||
|
indices: The nonzero indices of `value`
|
||||||
|
(if `value` is not a `SparseTensor`).
|
||||||
|
name: Optional op name.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `value` is a `SparseTensor` and `indices` is not `None`.
|
||||||
|
"""
|
||||||
|
super(CSRSparseMatrix, self).__init__()
|
||||||
|
if isinstance(value, sparse_tensor.SparseTensor):
|
||||||
|
if indices is not None:
|
||||||
|
raise ValueError("indices must be None if value is a SparseTensor.")
|
||||||
|
self._dtype = value.dtype
|
||||||
|
self._csr_matrix = sm_ops.sparse_tensor_to_csr_sparse_matrix(
|
||||||
|
indices=value.indices,
|
||||||
|
values=value.values,
|
||||||
|
dense_shape=value.dense_shape)
|
||||||
|
else:
|
||||||
|
value = ops.convert_to_tensor(value)
|
||||||
|
self._dtype = value.dtype
|
||||||
|
if indices is not None:
|
||||||
|
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64)
|
||||||
|
else:
|
||||||
|
indices = array_ops.stop_gradient(array_ops.where(value))
|
||||||
|
self._csr_matrix = sm_ops.dense_to_csr_sparse_matrix(value, indices)
|
||||||
|
|
||||||
|
# Eager mode doesn't call shape inference functions, so we have to set the
|
||||||
|
# shape and dtype handle data directly.
|
||||||
|
if self._eager_mode:
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self._csr_matrix._handle_data = _make_handle_data(value)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _matrix(self):
|
||||||
|
return self._csr_matrix
|
||||||
|
|
||||||
|
def _from_matrix(self, matrix, handle_data=None):
|
||||||
|
assert isinstance(matrix, ops.Tensor) and matrix.dtype == dtypes.variant
|
||||||
|
ret = type(self).__new__(type(self))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
ret._dtype = self._dtype
|
||||||
|
if self._eager_mode:
|
||||||
|
if matrix._handle_data is None:
|
||||||
|
matrix._handle_data = handle_data
|
||||||
|
assert matrix._handle_data is not None
|
||||||
|
ret._csr_matrix = matrix
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def to_dense(self):
|
||||||
|
return sm_ops.csr_sparse_matrix_to_dense(self._matrix, type=self.dtype)
|
||||||
|
|
||||||
|
def to_sparse_tensor(self):
|
||||||
|
r = sm_ops.csr_sparse_matrix_to_sparse_tensor(self._matrix, type=self.dtype)
|
||||||
|
return sparse_tensor.SparseTensor(
|
||||||
|
indices=r.indices, values=r.values, dense_shape=r.dense_shape)
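A hedged round-trip sketch of the constructor and conversions above (example values invented; as in the tests, the CSR kernels are assumed to be available, which currently means a GPU build):

from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops

st = sparse_tensor.SparseTensor(
    indices=[[0, 0], [2, 3]], values=[1.0, 2.0], dense_shape=[4, 5])
csr = sparse_csr_matrix_ops.CSRSparseMatrix(st)  # indices must stay None for SparseTensor input
dense = csr.to_dense()                           # [4, 5] float32 Tensor
st_back = csr.to_sparse_tensor()                 # SparseTensor with the same entries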
|
@ -102,3 +102,8 @@ tf_py_logged_benchmark(
|
|||||||
name = "rnn_op_benchmark",
|
name = "rnn_op_benchmark",
|
||||||
target = "//tensorflow/python/kernel_tests:rnn_test",
|
target = "//tensorflow/python/kernel_tests:rnn_test",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_logged_benchmark(
|
||||||
|
name = "sparse_csr_matrix_ops_benchmark",
|
||||||
|
target = "//tensorflow/python/kernel_tests:sparse_csr_matrix_ops_test_py",
|
||||||
|
)
|
||||||
|
3 third_party/eigen3/BUILD (vendored)
@ -18,7 +18,10 @@ EIGEN3_THIRD_PARTY_HEADERS = [
|
|||||||
"Eigen/LU",
|
"Eigen/LU",
|
||||||
"Eigen/Cholesky",
|
"Eigen/Cholesky",
|
||||||
"Eigen/Eigenvalues",
|
"Eigen/Eigenvalues",
|
||||||
|
"Eigen/OrderingMethods",
|
||||||
"Eigen/QR",
|
"Eigen/QR",
|
||||||
|
"Eigen/SparseCholesky",
|
||||||
|
"Eigen/SparseCore",
|
||||||
"Eigen/SVD",
|
"Eigen/SVD",
|
||||||
"unsupported/Eigen/MatrixFunctions",
|
"unsupported/Eigen/MatrixFunctions",
|
||||||
"unsupported/Eigen/SpecialFunctions",
|
"unsupported/Eigen/SpecialFunctions",
|
||||||
|
1 third_party/eigen3/Eigen/OrderingMethods (new vendored file)
@ -0,0 +1 @@
|
|||||||
|
#include "Eigen/OrderingMethods"
|
1 third_party/eigen3/Eigen/SparseCholesky (new vendored file)
@ -0,0 +1 @@
|
|||||||
|
#include "Eigen/SparseCholesky"
|
1 third_party/eigen3/Eigen/SparseCore (new vendored file)
@ -0,0 +1 @@
|
|||||||
|
#include "Eigen/SparseCore"
|
2 third_party/eigen3/LICENSE (vendored)
@ -533,10 +533,12 @@ Following applies to:
|
|||||||
./Eigen/src/MetisSupport/MetisSupport.h
|
./Eigen/src/MetisSupport/MetisSupport.h
|
||||||
./Eigen/StdVector
|
./Eigen/StdVector
|
||||||
./Eigen/Core
|
./Eigen/Core
|
||||||
|
./Eigen/OrderingMethods
|
||||||
./Eigen/SparseLU
|
./Eigen/SparseLU
|
||||||
./Eigen/StdList
|
./Eigen/StdList
|
||||||
./Eigen/StdDeque
|
./Eigen/StdDeque
|
||||||
./Eigen/SparseCholesky
|
./Eigen/SparseCholesky
|
||||||
|
./Eigen/SparseCore
|
||||||
./scripts/relicense.py
|
./scripts/relicense.py
|
||||||
./scripts/relicense.py
|
./scripts/relicense.py
|
||||||
./blas/BandTriangularSolver.h
|
./blas/BandTriangularSolver.h
|
||||||
|