From 867d3c97082cb2d26036d129ef7b51f3867a19d3 Mon Sep 17 00:00:00 2001
From: Derek Murray <mrry@google.com>
Date: Wed, 19 Feb 2020 15:45:36 -0800
Subject: [PATCH] [SparseTensor] Optimize the `tf.sparse.to_dense()`
 implementation.

This change includes several optimizations:

1. Introduce `SparseTensor::IndicesValidVectorFastPath()` for validating the
   indices of a 1-D SparseTensor. Like `IndicesValid32BitFastPath()`, the new
   method optimistically assumes that the tensor is valid and falls back to
   slower code in the failure case, but it does not have the 32-bit
   limitation. The compiler is able to vectorize the loop over the indices,
   for increased throughput. (Both this pattern and the 2-D fast path from
   item 2 are sketched after this list.)

2. Implement fast paths for 1-D and 2-D inputs in `SparseTensor::ToDense()`.
   The main win here comes from avoiding the data-dependent loop over
   dimensions when computing the index of the output value. We also avoid
   an unnecessary integer multiplication (by 1) in each case.

3. Minor optimizations to the case of 3 or more dimensions in
   `SparseTensor::ToDense()`, avoiding unnecessary calls to
   `TensorShape::dim_size()` and using pointer arithmetic rather than Eigen
   logic to dereference index elements.

4. Minor optimizations to the `SparseTensor::Create()` method, which now
   assigns directly to the relevant fields of the result instead of invoking
   the `SparseTensor` constructor and the move assignment operator. In this
   case the existing move logic wasn't saving us much, because the `Tensor` and
   `gtl::InlinedVector` move constructors still have to copy quite a lot of
   data.

5. Minor optimizations to the `SparseToDense::Compute()` method. In particular,
   we avoid allocating a temporary tensor for the indices when the input is
   DT_INT64 (which is the common case, since all `tf.SparseTensor` objects have
   64-bit indices).

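For illustration, here is a minimal standalone sketch of the two fast-path
patterns from items 1 and 2. This is plain C++, not the TensorFlow code
itself; the function names, the raw-pointer arguments, and the pre-zeroed
`std::vector` output are assumptions made for the sketch.

  #include <cstdint>
  #include <vector>

  // Item 1: accumulate both validity predicates in plain bools with `&`
  // (not `&&`) so the loop body has no data-dependent branches and the
  // compiler can vectorize it. Starting `prev` at -1 also rejects
  // negative indices via the ordering check.
  bool IndicesSortedAndInRange(const int64_t* ix, int64_t nnz, int64_t dim) {
    bool in_range = true;
    bool ordered = true;
    int64_t prev = -1;
    for (int64_t n = 0; n < nnz; ++n) {
      const int64_t index = ix[n];
      in_range = in_range & (index < dim);
      ordered = ordered & (index > prev);
      prev = index;
    }
    return in_range & ordered;
  }

  // Item 2: for a 2-D tensor the flat output offset is simply
  // `row * cols + col`, so there is no per-element loop over dimensions
  // and no multiplication by the innermost stride of 1.
  bool ToDense2D(const int64_t* ix,    // [nnz, 2] row-major (row, col) pairs
                 const float* vals,    // [nnz] values
                 int64_t nnz, int64_t rows, int64_t cols,
                 std::vector<float>* out) {  // size rows * cols, pre-zeroed
    for (int64_t n = 0; n < nnz; ++n) {
      const int64_t r = ix[n * 2];
      const int64_t c = ix[n * 2 + 1];
      if (r < 0 || r >= rows || c < 0 || c >= cols) return false;
      (*out)[static_cast<size_t>(r * cols + c)] = vals[n];
    }
    return true;
  }
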
PiperOrigin-RevId: 296075159
Change-Id: I0b051621920aec9b2a8dc6c7ecbf55e5b2d59098
---
 tensorflow/core/kernels/sparse_to_dense_op.cc | 44 +++++-----
 tensorflow/core/util/sparse/sparse_tensor.cc  | 54 +++++++++++--
 tensorflow/core/util/sparse/sparse_tensor.h   | 81 +++++++++++++------
 3 files changed, 133 insertions(+), 46 deletions(-)

diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc
index d9626052b0c..da4e7e070db 100644
--- a/tensorflow/core/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/core/kernels/sparse_to_dense_op.cc
@@ -20,14 +20,13 @@ limitations under the License.
 
 #define EIGEN_USE_THREADS
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-
 #include <numeric>
 #include <sstream>
 #include <string>
 #include <unordered_map>
 #include <utility>
 
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -35,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/util/ptr_util.h"
 #include "tensorflow/core/util/sparse/sparse_tensor.h"
 
 namespace tensorflow {
@@ -93,36 +93,44 @@ class SparseToDense : public OpKernel {
     Tensor* output = nullptr;
     OP_REQUIRES_OK(c, c->allocate_output(0, output_tensor_shape, &output));
 
-    TensorShape ix_shape({num_elems, num_dims});
-    Tensor indices_shaped(DT_INT64, ix_shape);
-    if (indices.dtype() == DT_INT64) {
-      CHECK(indices_shaped.CopyFrom(indices, ix_shape));
+    const Tensor* indices_shaped;
+    std::unique_ptr<Tensor> indices_shaped_holder;
+    if (indices.dtype() == DT_INT64 && indices.dims() == 2) {
+      indices_shaped = &indices;
     } else {
-      indices_shaped.matrix<int64>() =
-          indices.shaped<Index, 2>(ix_shape.dim_sizes()).template cast<int64>();
+      TensorShape ix_shape({num_elems, num_dims});
+      indices_shaped_holder = MakeUnique<Tensor>(DT_INT64, ix_shape);
+      indices_shaped = indices_shaped_holder.get();
+      if (indices.dtype() == DT_INT64) {
+        CHECK(indices_shaped_holder->CopyFrom(indices, ix_shape));
+      } else {
+        indices_shaped_holder->matrix<int64>() =
+            indices.shaped<Index, 2>(ix_shape.dim_sizes())
+                .template cast<int64>();
+      }
     }
 
     // If we received a scalar, we'll need to create a new
     // tensor with copies of the values as a vec.
-    // TODO(ebrevdo): find a way to avoid this temp allocation.
-    Tensor sparse_values_b;
+    const Tensor* sparse_values_b;
+    std::unique_ptr<Tensor> sparse_values_b_holder;
 
     if (TensorShapeUtils::IsScalar(sparse_values.shape())) {
-      OP_REQUIRES_OK(
-          c, c->allocate_temp(DataTypeToEnum<T>::value,
-                              TensorShape({num_elems}), &sparse_values_b));
-      sparse_values_b.vec<T>().setConstant(sparse_values.scalar<T>()());
+      sparse_values_b_holder = MakeUnique<Tensor>(DataTypeToEnum<T>::value,
+                                                  TensorShape({num_elems}));
+      sparse_values_b = sparse_values_b_holder.get();
+      sparse_values_b_holder->vec<T>().setConstant(sparse_values.scalar<T>()());
     } else {
-      sparse_values_b = sparse_values;
+      sparse_values_b = &sparse_values;
     }
 
     // Assume SparseTensor is lexicographically sorted.
     gtl::InlinedVector<int64, 8> order(output->shape().dims());
     std::iota(order.begin(), order.end(), 0);
     sparse::SparseTensor st;
-    OP_REQUIRES_OK(c,
-                   sparse::SparseTensor::Create(indices_shaped, sparse_values_b,
-                                                output->shape(), order, &st));
+    OP_REQUIRES_OK(
+        c, sparse::SparseTensor::Create(*indices_shaped, *sparse_values_b,
+                                        output->shape(), order, &st));
 
     if (validate_indices_) {
       OP_REQUIRES_OK(c, st.IndicesValid());
diff --git a/tensorflow/core/util/sparse/sparse_tensor.cc b/tensorflow/core/util/sparse/sparse_tensor.cc
index e58bd95f5a6..256ba57f1b6 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.cc
+++ b/tensorflow/core/util/sparse/sparse_tensor.cc
@@ -65,7 +65,11 @@ Status GetDimsFromIx(const Tensor& ix, int* result) {
     return errors::InvalidArgument("Shape rank must be SparseTensor rank.");
   }
 
-  *result = SparseTensor(std::move(ix), std::move(vals), shape, order);
+  result->ix_ = std::move(ix);
+  result->vals_ = std::move(vals);
+  result->shape_.assign(shape.begin(), shape.end());
+  result->order_.assign(order.begin(), order.end());
+  result->dims_ = dims;
   return Status::OK();
 }
 
@@ -108,6 +112,37 @@ SparseTensor::SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
   DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
 }
 
+// Optimized version of `IndicesValid()` with the following requirements:
+// * The sparse tensor is one-dimensional.
+//
+// Returns true if the indices are valid, otherwise false.
+// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
+// to obtain a meaningful error message.
+bool SparseTensor::IndicesValidVectorFastPath() const {
+  DCHECK_EQ(shape_.size(), 1);
+  DCHECK_EQ(order_[0], 0);
+
+  const int64 max_index = shape_[0];
+
+  // We maintain separate bools for each validation predicate to enable
+  // vectorization across loop iterations.
+  bool index_in_range_valid = true;
+  bool order_valid = true;
+
+  int64 prev_index = -1;
+  const auto ix_t = ix_.matrix<int64>();
+  const int64* const index_base_ptr = ix_t.data();
+
+  for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
+    const int64 index = index_base_ptr[n];
+    index_in_range_valid = index_in_range_valid & (index < max_index);
+    order_valid = order_valid & (index > prev_index);
+    prev_index = index;
+  }
+
+  return index_in_range_valid & order_valid;
+}
+
 // Optimized version of `IndicesValid()` with the following requirements:
 // * The sparse tensor is two-dimensional.
 // * The tensor's indices are in the "standard" (lexicographic) order.
@@ -116,7 +151,7 @@ SparseTensor::SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
 // Returns true if the indices are valid, otherwise false.
 // NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
 // to obtain a meaningful error message.
-bool SparseTensor::IndicesValid32BitFastPath() const {
+bool SparseTensor::IndicesValidMatrix32BitFastPath() const {
   const auto ix_t = ix_.matrix<int64>();
   const int64* const shape_ptr = shape_.data();
 
@@ -241,6 +276,10 @@ Status SparseTensor::IndicesValidHelper() const {
 }
 
 Status SparseTensor::IndicesValid() const {
+  if (shape_.size() == 1 && IndicesValidVectorFastPath()) {
+    return Status::OK();
+  }
+
   bool standard_order = true;
   for (size_t i = 0; i < order_.size(); ++i) {
     if (order_[i] < 0) {
@@ -252,9 +291,14 @@ Status SparseTensor::IndicesValid() const {
   }
 
   if (standard_order) {
-    if (shape_.size() == 2 && shape_[0] <= std::numeric_limits<int32>::max() &&
-        shape_[1] <= std::numeric_limits<int32>::max()) {
-      if (IndicesValid32BitFastPath()) {
+    if (shape_.size() == 1) {
+      if (IndicesValidVectorFastPath()) {
+        return Status::OK();
+      }
+    } else if (shape_.size() == 2 &&
+               shape_[0] <= std::numeric_limits<int32>::max() &&
+               shape_[1] <= std::numeric_limits<int32>::max()) {
+      if (IndicesValidMatrix32BitFastPath()) {
         return Status::OK();
       }
     }
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 03ae4fe3f68..2654d126e86 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -201,7 +201,14 @@ class SparseTensor {
     return vec;
   }
 
-  bool IndicesValid32BitFastPath() const;
+  // Optimized implementation of `IndicesValid` for 1-D sparse tensors.
+  // REQUIRES: `shape_.size() == 1`.
+  bool IndicesValidVectorFastPath() const;
+
+  // Optimized implementation of `IndicesValid` for 2-D sparse tensors whose
+  // indices fit within the range of an `int32`.
+  // REQUIRES: `shape_.size() == 2`.
+  bool IndicesValidMatrix32BitFastPath() const;
 
   template <bool standard_order>
   Status IndicesValidHelper() const;
@@ -354,32 +361,60 @@ inline bool SparseTensor::ToDense(Tensor* out, bool initialize) {
   if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;
 
   auto out_t = out->flat<T>();
-  auto ix_t = ix_.matrix<int64>();
   auto vals_t = vals_.vec<T>();
+  auto ix_t = ix_.matrix<int64>();
+  const int64* const ix_ptr = ix_t.data();
 
-  std::vector<int64> strides(dims_);
-  const auto& out_shape = out->shape();
-  if (dims_ > 0) {
-    strides[dims_ - 1] = 1;
-  }
-  for (int d = dims_ - 2; d >= 0; --d) {
-    strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
-  }
-
-  for (int n = 0; n < vals_t.dimension(0); ++n) {
-    bool invalid_dims = false;
-    int64 ix = 0;
-    for (int d = 0; d < dims_; ++d) {
-      const int64 ix_n_d = internal::SubtleMustCopy(ix_t(n, d));
-      if (!FastBoundsCheck(ix_n_d, out_shape.dim_size(d))) {
-        invalid_dims = true;
-      }
-      ix += strides[d] * ix_n_d;
+  if (dims_ == 1) {
+    // Fast path for sparse vectors.
+    const int64 out_length = out->shape().dim_size(0);
+    for (int n = 0; n < vals_t.dimension(0); ++n) {
+      const int64 index = internal::SubtleMustCopy(ix_ptr[n]);
+      if (!FastBoundsCheck(index, out_length)) return false;
+      out_t(index) = vals_t(n);
     }
-    if (invalid_dims) return false;
-    out_t(ix) = vals_t(n);
+    return true;
+  } else if (dims_ == 2) {
+    // Fast path for sparse matrices.
+    const auto& out_shape = out->shape();
+    const int64 out_rows = out_shape.dim_size(0);
+    const int64 out_cols = out_shape.dim_size(1);
+    for (int n = 0; n < vals_t.dimension(0); ++n) {
+      const int64 row_index = internal::SubtleMustCopy(ix_ptr[n * 2]);
+      const int64 col_index = internal::SubtleMustCopy(ix_ptr[n * 2 + 1]);
+      if (!(FastBoundsCheck(row_index, out_rows) &&
+            FastBoundsCheck(col_index, out_cols))) {
+        return false;
+      }
+      out_t(row_index * out_cols + col_index) = vals_t(n);
+    }
+    return true;
+  } else {
+    // General path for N-dimensional sparse tensors.
+    gtl::InlinedVector<int64, 4> strides(dims_);
+    const auto& out_shape = out->shape().dim_sizes();
+    if (dims_ > 0) {
+      strides[dims_ - 1] = 1;
+    }
+    for (int d = dims_ - 2; d >= 0; --d) {
+      strides[d] = strides[d + 1] * out_shape[d + 1];
+    }
+
+    for (int n = 0; n < vals_t.dimension(0); ++n) {
+      bool invalid_dims = false;
+      int64 ix = 0;
+      for (int d = 0; d < dims_; ++d) {
+        const int64 ix_n_d = internal::SubtleMustCopy(ix_ptr[n * dims_ + d]);
+        if (!FastBoundsCheck(ix_n_d, out_shape[d])) {
+          invalid_dims = true;
+        }
+        ix += strides[d] * ix_n_d;
+      }
+      if (invalid_dims) return false;
+      out_t(ix) = vals_t(n);
+    }
+    return true;
   }
-  return true;
 }
 
 template <typename T>