diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 30026f222a6..30c57ef287f 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -65,7 +65,8 @@ class SparseTensorDenseMatMulOp : public OpKernel {
     OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()),
                 errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));
-    OP_REQUIRES(ctx, a_indices->shape().dim_size(0) == a_values->NumElements(),
+    const int64 nnz = a_indices->shape().dim_size(0);
+    OP_REQUIRES(ctx, nnz == a_values->NumElements(),
                 errors::InvalidArgument("Number of rows of a_indices does not "
                                         "match number of entries in a_values"));

@@ -89,8 +90,28 @@ class SparseTensorDenseMatMulOp : public OpKernel {
                     inner_left, " vs. ", inner_right,
                     ". Did you forget a transpose? "
                     "Dimensions of A: [",
-                    a_shape_t(0), ", ", a_shape_t(1), "). Dimensions of B: ",
-                    b->shape().DebugString()));
+                    a_shape_t(0), ", ", a_shape_t(1),
+                    "). Dimensions of B: ", b->shape().DebugString()));
+
+    if (std::is_same<Device, GPUDevice>::value) {
+      // The GPU implementation is optimized to use 32 bit indexing, so
+      // give a friendly error to the programmer early on if they
+      // exceed.
+      const int int32max = std::numeric_limits<int>::max();
+      OP_REQUIRES(
+          ctx,
+          (FastBoundsCheck(inner_left, int32max) &&
+           FastBoundsCheck(inner_right, int32max) &&
+           FastBoundsCheck(outer_left, int32max) &&
+           FastBoundsCheck(outer_right, int32max) &&
+           FastBoundsCheck(b->NumElements(), int32max) &&
+           FastBoundsCheck(outer_left * outer_right, int32max) &&
+           FastBoundsCheck(a_values->NumElements(), int32max)),
+          errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
+      OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
+                  errors::InvalidArgument(
+                      "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
+    }

     TensorShape out_shape({outer_left, outer_right});
     Tensor* out = nullptr;
@@ -111,41 +132,13 @@ class SparseTensorDenseMatMulOp : public OpKernel {
       return;
     }

-    Tensor scratch;
-
-    if (std::is_same<Device, GPUDevice>::value) {
-      // The GPU implementation is optimized to use 32 bit indexing, so
-      // give a friendly error to the programmer early on if they exceed.
-      OP_REQUIRES(
-          ctx,
-          FastBoundsCheck(inner_left, std::numeric_limits<int>::max()) &&
-              FastBoundsCheck(inner_right, std::numeric_limits<int>::max()) &&
-              FastBoundsCheck(outer_left, std::numeric_limits<int>::max()) &&
-              FastBoundsCheck(outer_right, std::numeric_limits<int>::max()) &&
-              FastBoundsCheck(b->NumElements(),
-                              std::numeric_limits<int>::max()) &&
-              FastBoundsCheck(out->NumElements(),
-                              std::numeric_limits<int>::max()) &&
-              FastBoundsCheck(a_values->NumElements(),
-                              std::numeric_limits<int>::max()),
-          errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
-      const int nnz = static_cast<int>(a_values->NumElements());
-      // Need nnz length vec scratch space on the GPU.
-      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                             TensorShape({nnz}), &scratch));
-    } else {
-      // We don't need scratch space on the CPU.
-      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                             TensorShape({0}), &scratch));
-    }
-
 #define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                        \
   if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                        \
     Status functor_status = functor::SparseTensorDenseMatMulFunctor<       \
         Device, T, Tindices, ADJ_A,                                        \
         ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(),     \
                         a_indices->matrix<Tindices>(), a_values->vec<T>(), \
-                        b->matrix<T>(), scratch.vec<T>());                 \
+                        b->matrix<T>());                                   \
     OP_REQUIRES_OK(ctx, functor_status);                                   \
   }

@@ -189,10 +182,9 @@ namespace functor {
   Status SparseTensorDenseMatMulFunctor<                                   \
       GPUDevice, T, Tindices, ADJ_A,                                       \
       ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,  \
-                      typename TTypes<Tindices>::ConstMatrix a_indices,    \
+                      TTypes<Tindices>::ConstMatrix a_indices,             \
                       typename TTypes<T>::ConstVec a_values,               \
-                      typename TTypes<T>::ConstMatrix b,                   \
-                      typename TTypes<T>::Vec scratch);                    \
+                      typename TTypes<T>::ConstMatrix b);                  \
   extern template struct SparseTensorDenseMatMulFunctor<                   \
       GPUDevice, T, Tindices, ADJ_A, ADJ_B>;

@@ -255,8 +247,7 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
                         typename TTypes<Tindices>::ConstMatrix a_indices,
                         typename TTypes<T>::ConstVec a_values,
-                        typename TTypes<T>::ConstMatrix b,
-                        typename TTypes<T>::Vec scratch) {
+                        typename TTypes<T>::ConstMatrix b) {
     const std::size_t nnz = a_values.size();
     const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
     const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index e707743f782..da131904949 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -28,11 +28,10 @@ namespace functor {
 template <typename Device, typename T, typename Tindices, bool ADJ_A,
           bool ADJ_B>
 struct SparseTensorDenseMatMulFunctor {
-  static EIGEN_ALWAYS_INLINE Status
-  Compute(const Device& d, typename TTypes<T>::Matrix out,
-          typename TTypes<Tindices>::ConstMatrix a_indices,
-          typename TTypes<T>::ConstVec a_values,
-          typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
+  static EIGEN_ALWAYS_INLINE Status Compute(
+      const Device& d, typename TTypes<T>::Matrix out,
+      typename TTypes<Tindices>::ConstMatrix a_indices,
+      typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
 };

 template <typename MATRIX, bool ADJ>
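For reference, the contract that the reworked functor implements, now without the scratch tensor, is simply out = op(A) * op(B), with A supplied in COO form as (a_indices, a_values) plus a dense shape. The NumPy sketch below mirrors that contract; none of these names exist in TensorFlow, it only illustrates what the CPU and GPU specializations must compute.

import numpy as np

def sparse_dense_matmul_reference(a_indices, a_values, a_shape, b,
                                  adjoint_a=False, adjoint_b=False):
  # Densify A from COO form (assumes no duplicate index rows), then matmul.
  a = np.zeros(a_shape, dtype=b.dtype)
  a[a_indices[:, 0], a_indices[:, 1]] = a_values
  if adjoint_a:
    a = a.conj().T
  if adjoint_b:
    b = b.conj().T
  return a @ b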
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index 7266e0cf812..e261e42e0d3 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -20,71 +20,45 @@ limitations under the License.
 #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"

 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"

 namespace tensorflow {

 typedef Eigen::GpuDevice GPUDevice;

-namespace generator {
-
 template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
-class SparseTensorDenseMatMulGPUGenerator {
- public:
-  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
-      typename TTypes<T, 2>::Tensor32Bit out,
-      typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
-      typename TTypes<const T, 1>::Tensor32Bit a_values,
-      typename TTypes<const T, 2>::Tensor32Bit b)
-      : out_(out),
-        lhs_index_a_(ADJ_A ? 1 : 0),
-        rhs_index_a_(ADJ_A ? 0 : 1),
-        a_indices_(a_indices),
-        a_values_(a_values),
-        lhs_right_size(ADJ_B ? b.dimension(1) : b.dimension(0)),
-        maybe_adjoint_b_(
-            functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit,
-                                  ADJ_B>(b)) {}
-
-  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
-  operator()(const Eigen::array<int, 2>& j_and_ix) const {
-#ifdef __CUDA_ARCH__
-    const int j = j_and_ix[0];
-    const int ix = j_and_ix[1];
-    int m = a_indices_(ix, lhs_index_a_);
-    int k = a_indices_(ix, rhs_index_a_);
-    assert(k < lhs_right_size);
-    assert(m < out_.dimension(0));
-    // If asserts are disabled, the caller is violating the sparse
-    // tensor index contract, and so we return invalid results.
-    // Force returning NaNs to try to signal that something is amiss.
-    T b_value;
-    if (k >= lhs_right_size || m >= out_.dimension(0)) {
-      m = 0;
-      k = 0;
-      b_value = std::numeric_limits<T>::quiet_NaN();
-    } else {
-      b_value = maybe_adjoint_b_(k, j);
+__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
+                                              int b_cols, int p,
+                                              const Tindices* a_indices,
+                                              const T* a_values, const T* b,
+                                              T* out) {
+  // out_{ij} = sum_k {a_ik b_kj}
+  // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
+  const int n = (ADJ_B) ? b_cols : b_rows;
+  CUDA_1D_KERNEL_LOOP(index, nnz * p) {
+    const int a_ix = index / p;
+    const int j = index % p;
+    const int i = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 1 : 0));
+    const int k = ldg(a_indices + 2 * a_ix + ((ADJ_A) ? 0 : 1));
+    if (!FastBoundsCheck(i, m)) {
+      continue;  // Nowhere to signal an error :(
     }
-    atomicAdd(&out_(m, j), a_values_(ix) * b_value);
-#else
-    assert(false && "This should only be run on the device");
-#endif
-    // Return something
-    return T(0);
+    // out[i, j]
+    T* out_location = out + i * p + j;
+    if (!FastBoundsCheck(k, n)) {
+      CudaAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
+      continue;
+    }
+
+    // a_value == (ADJ_A) ? a[k, i] : a[i, k]
+    const T a_value = ldg(a_values + a_ix);
+
+    // b_value == (ADJ_B) ? b[j, k] : b[k, j]
+    const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
+    CudaAtomicAdd(out_location, a_value * b_value);
   }
-
- private:
-  mutable typename TTypes<T, 2>::Tensor32Bit out_;
-  const int lhs_index_a_;
-  const int rhs_index_a_;
-  typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
-  typename TTypes<const T, 1>::Tensor32Bit a_values_;
-  const int lhs_right_size;
-  functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
-      maybe_adjoint_b_;
-};
-
-}  // namespace generator
+}

 namespace functor {

@@ -94,51 +68,23 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
           typename TTypes<Tindices>::ConstMatrix a_indices,
           typename TTypes<T>::ConstVec a_values,
-          typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
-    generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
-        sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
-                                             To32Bit(a_values), To32Bit(b));
-    To32Bit(out).device(d) = To32Bit(out).constant(T(0));
+          typename TTypes<T>::ConstMatrix b) {
+    out.device(d) = out.constant(T(0));
     int nnz = a_values.size();
-    int n = (ADJ_B) ? b.dimension(0) : b.dimension(1);
+    // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
+    int m = out.dimension(0);
+    int p = out.dimension(1);
+    int b_rows = b.dimension(0);
+    int b_cols = b.dimension(1);

-#if !defined(EIGEN_HAS_INDEX_LIST)
-    Eigen::Tensor<int, 2>::Dimensions matrix_1_by_nnz{{ 1, nnz }};
-    Eigen::array<int, 2> n_by_1{{ n, 1 }};
-    Eigen::array<int, 1> reduce_on_rows{{ 0 }};
-#else
-    Eigen::IndexList<Eigen::type2index<1>, int> matrix_1_by_nnz;
-    matrix_1_by_nnz.set(1, nnz);
-    Eigen::IndexList<int, Eigen::type2index<1> > n_by_1;
-    n_by_1.set(0, n);
-    Eigen::IndexList<Eigen::type2index<0> > reduce_on_rows;
-#endif
+    // TODO(ebrevdo): Should this be alpha * nnz instead of
+    // out.size()? Perhaps p * nnz ?
+    CudaLaunchConfig config = GetCudaLaunchConfig(p * nnz, d);

-    // How this works: the generator iterates over (j, ix) where j
-    // iterates from 0 .. n - 1 and ix iterates from
-    // 0 .. nnz - 1.  A side effect of the generator is to accumulate
-    // the products of values in A and B into the appropriate location
-    // in the dense matrix out.  In order to run the iteration,
-    // we take a smaller variable and broadcast to a size (n, nnz).
-    // This is the scratch variable.  In order to enforce execution,
-    // we have to perform assignment back into scratch (taking the sum).
-    // We don't care what gets assigned to scratch - only the side effect
-    // of the execution in the generator.
-    //
-    // Note it's not sufficient that scratch be a scalar, and to
-    // broadcast it to a matrix.  Eigen splits the computation not
-    // based on the largest intermediate shape (the size of the
-    // broadcast of scratch) but based on the output shape.  So
-    // scratch needs to be a vector at least.
-    //
-    // Note also that only float type is supported because the
-    // atomicAdd operation is only supported for floats in hardware.
-    To32Bit(scratch).device(d) =
-        To32Bit(scratch)
-            .reshape(matrix_1_by_nnz)
-            .broadcast(n_by_1)
-            .generate(sparse_tensor_dense_matmul_generator)
-            .sum(reduce_on_rows);
+    SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>
+        <<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
+            nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
+            b.data(), out.data());

     return Status::OK();
   }
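The new kernel flattens its iteration space to nnz * p: each iteration of CUDA_1D_KERNEL_LOOP handles one (nonzero, output column) pair, recovers the pair with a divide and a modulo, and accumulates into out[i, j] with CudaAtomicAdd. A nonzero whose row lands outside the output is skipped silently, while an out-of-range column poisons the affected output entries with NaN. The short Python sketch below walks the same loop on the host; the function name and arguments are illustrative only, not TensorFlow API.

import numpy as np

def sparse_dense_matmul_scatter(a_indices, a_values, b, out_shape,
                                adj_a=False, adj_b=False):
  out = np.zeros(out_shape, dtype=b.dtype)   # the functor zeroes out first
  nnz = a_values.shape[0]
  m, p = out.shape
  b_rows, b_cols = b.shape
  n = b_cols if adj_b else b_rows
  for index in range(nnz * p):               # CUDA_1D_KERNEL_LOOP(index, nnz * p)
    a_ix, j = divmod(index, p)
    i = a_indices[a_ix, 1 if adj_a else 0]
    k = a_indices[a_ix, 0 if adj_a else 1]
    if not 0 <= i < m:
      continue                               # nowhere to signal an error
    if not 0 <= k < n:
      out[i, j] += np.nan                    # poison the output instead
      continue
    b_value = b[j, k] if adj_b else b[k, j]
    out[i, j] += a_values[a_ix] * b_value    # CudaAtomicAdd on the GPU
  return out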
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index 80991751860..a0bd178e247 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -161,6 +162,46 @@ class SparseTensorDenseMatMulTest(test.TestCase):
       sparse_ops.sparse_tensor_dense_matmul(
           sparse_t, dense_t, adjoint_a=True).eval()

+  def testInvalidIndicesForSparseTensorDenseMatmulOnGPU(self):
+    # Note: use_gpu=False because nice errors are only returned from CPU kernels.
+    if not test.is_gpu_available():
+      return
+    with self.test_session(use_gpu=True):
+      indices = np.array([[1, 10]]).astype(np.int64)
+      values = np.array([10]).astype(np.float32)
+      shape = [3, 2]
+      sparse_t = sparse_tensor.SparseTensor(indices, values, shape)
+
+      # Test multiplying by both a small and large dense matrix, to hit
+      # both cases in the kernel.
+      dense_t = np.matrix([[1] * 5, [2] * 5], dtype=np.float32)
+      expected_t = np.array([[0] * 5, [np.nan] * 5, [0] * 5], dtype=np.float32)
+      self.assertAllClose(expected_t,
+                          sparse_ops.sparse_tensor_dense_matmul(
+                              sparse_t, dense_t).eval())
+      dense_t = np.matrix([[1] * 500, [2] * 500], dtype=np.float32)
+      expected_t = np.array(
+          [[0] * 500, [np.nan] * 500, [0] * 500], dtype=np.float32)
+      self.assertAllClose(expected_t,
+                          sparse_ops.sparse_tensor_dense_matmul(
+                              sparse_t, dense_t).eval())
+
+      # Repeat with adjoint_a, now the error is that the sparse index
+      # is out of bounds w.r.t. the output.  The GPU kernel can't do much
+      # here, so it just doesn't accumulate.
+
+      dense_t = np.matrix([[1] * 5, [2] * 5, [3] * 5], dtype=np.float32)
+      expected_t = np.array([[0] * 5, [0] * 5], dtype=np.float32)
+      self.assertAllClose(expected_t,
+                          sparse_ops.sparse_tensor_dense_matmul(
+                              sparse_t, dense_t, adjoint_a=True).eval())
+
+      dense_t = np.matrix([[1] * 500, [2] * 500, [3] * 500], dtype=np.float32)
+      expected_t = np.array([[0] * 500, [0] * 500], dtype=np.float32)
+      self.assertAllClose(expected_t,
+                          sparse_ops.sparse_tensor_dense_matmul(
+                              sparse_t, dense_t, adjoint_a=True).eval())
+
   # Tests setting one dimension to be a high value.
   def _testLarge(self, np_dtype):
     r1 = np.random.randint(6000, 20000)
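Working the first GPU case of the test above through by hand: the single nonzero claims column 10 of a [3, 2] sparse matrix, so there is no valid row of the dense matrix to read and the kernel adds NaN into output row 1 instead, which is exactly the expected_t value asserted above. A NumPy-only check of that reasoning:

import numpy as np

out = np.zeros((3, 5), dtype=np.float32)
i, k = 1, 10    # the single nonzero sits at [1, 10] in the sparse matrix
n = 2           # b has 2 rows, so the valid range for k is [0, 2)
if not 0 <= k < n:
  out[i, :] += np.nan   # mirrors CudaAtomicAdd(out_location, quiet_NaN())
print(out)  # rows 0 and 2 stay zero; row 1 is all NaN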
@@ -175,9 +216,12 @@ class SparseTensorDenseMatMulTest(test.TestCase):
     y = _maybe_complex(np.random.randn(k, n).astype(np_dtype))

-    self._testMatmul(x, y)
+    self._testMatmul(x, y, adjoint_a=False, adjoint_b=False)
+    self._testMatmul(x.transpose(), y, adjoint_a=True, adjoint_b=False)
+    self._testMatmul(x, y.transpose(), adjoint_a=False, adjoint_b=True)
+    self._testMatmul(
+        x.transpose(), y.transpose(), adjoint_a=True, adjoint_b=True)

-  def testLarge(self):
     np.random.seed(127)  # Repeatable results
     self._testLarge(np.float32)
     self._testLarge(np.float64)
@@ -221,7 +265,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x, y, adjoint_a,
         lambda t, _: t < iterations,
         body, (t0, v0),
         parallel_iterations=1,
-        back_prop=False)
+        back_prop=False,
+        shape_invariants=(tensor_shape.TensorShape(()),
+                          tensor_shape.TensorShape(None)))
     return [final]

   return _timeit
@@ -246,7 +292,9 @@ def _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(x_ind, x_val, x_shape,
         lambda t, _: t < iterations,
         body, (t0, v0),
         parallel_iterations=1,
-        back_prop=False)
+        back_prop=False,
+        shape_invariants=(tensor_shape.TensorShape(()),
+                          tensor_shape.TensorShape(None)))
     return [final]

   return _timeit
@@ -291,7 +339,7 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
   if skip_dense:
     delta_dense = float("nan")
   else:
-    with session.Session("", config=config, graph=ops.Graph()) as sess:
+    with session.Session(config=config, graph=ops.Graph()) as sess:
       if not use_gpu:
         with ops.device("/cpu:0"):
           x_t = constant_op.constant(x)
@@ -299,12 +347,12 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
           ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
               x_t, y_t, adjoint_a, adjoint_b)
       else:
-        x_t = constant_op.constant(x)
-        y_t = constant_op.constant(y)
-        ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(x_t, y_t,
-                                                                      adjoint_a,
-                                                                      adjoint_b)
-      delta_dense = _timer(sess, ops_fn, 1000)
+        with ops.device("/gpu:0"):
+          x_t = constant_op.constant(x)
+          y_t = constant_op.constant(y)
+          ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_dense(
+              x_t, y_t, adjoint_a, adjoint_b)
+      delta_dense = _timer(sess, ops_fn, 200)

   # Using sparse_tensor_dense_matmul.
   with session.Session("", config=config, graph=ops.Graph()) as sess:
@@ -317,13 +365,14 @@ def sparse_tensor_dense_vs_dense_matmul_benchmark(thresh,
     if not use_gpu:
       with ops.device("/cpu:0"):
         x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
         x_val = constant_op.constant(x[np.where(x)])
         x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
         y_t = constant_op.constant(y)
         ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
             x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
     else:
-      x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
-      x_val = constant_op.constant(x[np.where(x)])
-      x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
-      y_t = constant_op.constant(y)
-      ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
-          x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
-    delta_sparse = _timer(sess, ops_fn, 1000)
+      with ops.device("/gpu:0"):
+        x_ind = constant_op.constant(np.vstack(np.where(x)).astype(np.int64).T)
+        x_val = constant_op.constant(x[np.where(x)])
+        x_shape = constant_op.constant(np.array(x.shape).astype(np.int64))
+        y_t = constant_op.constant(y)
+        ops_fn = _sparse_tensor_dense_vs_dense_matmul_benchmark_sparse(
+            x_ind, x_val, x_shape, y_t, adjoint_a, adjoint_b)
+    delta_sparse = _timer(sess, ops_fn, 200)

   print("%g \t %d \t %s \t %d \t %d \t %g \t %g \t %g" %
         (1 - thresh, n, use_gpu, m, k, delta_dense, delta_sparse,
@@ -340,7 +389,7 @@ def main(_):
         "\t dt(sparse)/dt(dense)")

   for thresh in (0.99, 0.8, 0.5, 0.2):
-    for n in (1, 10, 25):
+    for n in (50, 100):
       for use_gpu in (True, False):
         for m in (100, 1000):
           for k in (100, 1000):
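The shape_invariants argument added to both benchmark loops relaxes while_loop's static shape check on the value carried through the loop; presumably the matmul result's inferred shape does not have to match the initial value v0 exactly. The toy TF 1.x style snippet below only demonstrates that mechanism; the shapes and ops in it are made up for illustration and are not taken from the benchmark.

import tensorflow as tf

t0 = tf.constant(0)
v0 = tf.zeros([2, 3])

def body(t, v):
  # The carried value changes shape on every iteration, so its static shape
  # cannot be pinned to v0's; TensorShape(None) below relaxes that check.
  return t + 1, tf.concat([v, v], axis=0)

_, final = tf.while_loop(
    lambda t, _: t < 3, body, (t0, v0),
    parallel_iterations=1, back_prop=False,
    shape_invariants=(tf.TensorShape([]), tf.TensorShape(None)))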