From 3b336d3173d5c8e84531bd97cc0e452ac9a56a7e Mon Sep 17 00:00:00 2001
From: Derek Murray <mrry@google.com>
Date: Fri, 24 Apr 2020 15:30:26 -0700
Subject: [PATCH] [Grappler] Add arithmetic optimizer stage for optimizing
 `tf.nn.embedding_lookup_sparse()`.

This optimization eliminates unnecessary `tf.unique()` and `tf.gather()` operations from `tf.nn.embedding_lookup_sparse()` when the embeddings are unpartitioned (e.g. at inference time) and weights are not used. Instead, the `tf.sparse.segment_<combiner>()` operation is applied directly to the embedding matrix without uniquifying the IDs.

PiperOrigin-RevId: 308338718
Change-Id: I6e689610a8f4f3dd0a3e8af77cce609ac6d4f9f9
---
 tensorflow/core/grappler/op_types.cc          |  19 ++++
 tensorflow/core/grappler/op_types.h           |   3 +
 .../optimizers/arithmetic_optimizer.cc        | 104 ++++++++++++++++++
 .../optimizers/arithmetic_optimizer.h         |   1 +
 .../optimizers/arithmetic_optimizer_test.cc   |  44 ++++++++
 .../arithmetic_optimizer_test_utils.h         |   6 +
 .../feature_column/feature_column_test.py     |   2 +-
 7 files changed, 178 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 8b8df527041..9e3b401154a 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -76,6 +76,15 @@ bool IsAnyMin(const NodeDef& node) {
   return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
 }
 
+bool IsAnySparseSegmentReduction(const NodeDef& node) {
+  // Covers the Sum/Mean/SqrtN combiners, with or without "WithNumSegments".
+  return node.op() == "SparseSegmentSum" || node.op() == "SparseSegmentMean" ||
+         node.op() == "SparseSegmentSqrtN" ||
+         node.op() == "SparseSegmentSumWithNumSegments" ||
+         node.op() == "SparseSegmentMeanWithNumSegments" ||
+         node.op() == "SparseSegmentSqrtNWithNumSegments";
+}
+
 bool IsApproximateEqual(const NodeDef& node) {
   return node.op() == "ApproximateEqual";
 }
@@ -268,6 +277,11 @@ bool IsFusedBatchNormGrad(const NodeDef& node) {
          op == "FusedBatchNormGradV3";
 }
 
+bool IsGather(const NodeDef& node) {
+  // Matches either the "Gather" or the "GatherV2" op type.
+  return node.op() == "Gather" || node.op() == "GatherV2";
+}
+
 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
 
 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
@@ -589,6 +603,11 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
 
 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
 
+bool IsUnique(const NodeDef& node) {
+  // Matches either the "Unique" or the "UniqueV2" op type.
+  return node.op() == "Unique" || node.op() == "UniqueV2";
+}
+
 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
 
 bool IsVariable(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 88d81c5b202..b1624ac70c6 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -34,6 +34,7 @@ bool IsAnyMax(const NodeDef& node);
 bool IsAnyMaxPool(const NodeDef& node);
 bool IsAnyMin(const NodeDef& node);
 bool IsAnyMul(const NodeDef& node);
+bool IsAnySparseSegmentReduction(const NodeDef& node);
 bool IsApproximateEqual(const NodeDef& node);
 bool IsArg(const NodeDef& node);
 bool IsArgMax(const NodeDef& node);
@@ -81,6 +82,7 @@ bool IsFloorMod(const NodeDef& node);
 bool IsFusedBatchNorm(const NodeDef& node);
 bool IsFusedBatchNormEx(const NodeDef& node);
 bool IsFusedBatchNormGrad(const NodeDef& node);
+bool IsGather(const NodeDef& node);
 bool IsGreater(const NodeDef& node);
 bool IsGreaterEqual(const NodeDef& node);
 bool IsHistogramSummary(const NodeDef& node);
@@ -187,6 +189,7 @@ bool IsTile(const NodeDef& node);
 bool IsTranspose(const NodeDef& node);
 bool IsTruncateDiv(const NodeDef& node);
 bool IsTruncateMod(const NodeDef& node);
+bool IsUnique(const NodeDef& node);
 bool IsUnpack(const NodeDef& node);
 bool IsVariable(const NodeDef& node);
 bool IsWhile(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 200572456c3..b9502ffb45e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3466,6 +3466,108 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
   }
 };
 
+// Eliminates unnecessary copies during sparse embedding lookup operations.
+//
+// For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
+// generates code of the form:
+//
+//     embeddings = <a 2D Tensor>
+//     sparse_ids = <a tf.int64 SparseTensor>
+//     segment_ids = sparse_ids.indices[:, 0]
+//     ids, idx = tf.unique(sparse_ids.values)
+//     gathered_rows = tf.gather(embeddings, ids)
+//     result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
+//
+// In this case, all of the work in `tf.unique()` and `tf.gather()`
+// can be avoided by passing the full embeddings to
+// `tf.sparse.segment_<combiner>()` and performing the same amount of
+// computation (but fewer copies and allocations) as follows:
+//
+//     embeddings = <a 2D Tensor>
+//     sparse_ids = <a tf.int64 SparseTensor>
+//     segment_ids = sparse_ids.indices[:, 0]
+//     result = tf.sparse.segment_<combiner>(
+//          embeddings, sparse_ids.values, segment_ids)
+class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
+ public:
+  explicit SimplifyEmbeddingLookupStage(
+      const GraphOptimizerContext& ctx,
+      const ArithmeticOptimizerContext& ctx_ext)
+      : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
+  }
+  ~SimplifyEmbeddingLookupStage() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsAnySparseSegmentReduction(*node);
+  }
+
+  Status TrySimplify(NodeDef* reduction_node,
+                     string* simplified_node_name) override {
+    if (IsInPreserveSet(*reduction_node)) return Status::OK();
+
+    // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
+    // axis.
+    NodeDef* gather_node = nullptr;
+    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
+    if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
+        gather_node->device() != reduction_node->device())
+      return Status::OK();
+    if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
+      return Status::OK();
+
+    // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
+    // axis.
+    NodeDef* unique_node = nullptr;
+    TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
+    if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
+        unique_node->device() != gather_node->device())
+      return Status::OK();
+    if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
+      return Status::OK();
+
+    DataType unique_element_type;
+    TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));
+
+    // Input 1 (indices) of the reduction node must be output 1 of the unique
+    // node.
+    const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
+    if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK();
+
+    // Input 0 (data) of the reduction node becomes input 0 (params) of the
+    // gather node. The NodeMap is updated while the old input is still set.
+    ctx().node_map->UpdateInput(reduction_node->name(),
+                                reduction_node->input(0),
+                                gather_node->input(0));
+    reduction_node->set_input(0, gather_node->input(0));
+
+    // Input 1 (indices) of the reduction node becomes input 0 (x) of the unique
+    // node, whose element type "T" is the new "Tidx" of the reduction node.
+    ctx().node_map->UpdateInput(reduction_node->name(),
+                                reduction_node->input(1),
+                                unique_node->input(0));
+    reduction_node->set_input(1, unique_node->input(0));
+    SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);
+
+    *simplified_node_name = reduction_node->name();
+    return Status::OK();
+  }
+
+ private:
+  // Returns true iff input `axis_input` of `node` can be read as a constant
+  // tensor holding exactly one int32/int64 element whose value is 0.
+  bool IsAxis0(const NodeDef& node, int axis_input) {
+    Tensor axis_tensor;
+    if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
+      return false;
+    if (axis_tensor.NumElements() != 1) return false;
+    if (axis_tensor.dtype() == DT_INT32)
+      return axis_tensor.flat<int32>()(0) == 0;
+    if (axis_tensor.dtype() == DT_INT64)
+      return axis_tensor.flat<int64>()(0) == 0;
+    return false;
+  }
+};
+
 }  // namespace
 
 Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
@@ -3538,6 +3640,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
   if (options_.fuse_squared_diff)
     pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);
+  if (options_.simplify_embedding_lookup)
+    pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
 
   VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
           << absl::StrJoin(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 76aca8b840e..aa4762ff5c3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -85,6 +85,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool convert_expm1 = true;
     bool unary_ops_composition = true;
     bool remove_stack_slice_same_axis = true;
+    bool simplify_embedding_lookup = true;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
     // optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 50896b11923..8da306190a0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -4085,5 +4085,49 @@ TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
   test::ExpectTensorEqual<bfloat16>(tensors[0], tensors_expected[0]);
 }
 
+TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) {
+  for (DataType unique_idx_type : {DT_INT32, DT_INT64}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output embeddings = ops::Const(s.WithOpName("embeddings"),
+                                   {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+    Output segment_ids =
+        ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2});
+    Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1});
+    auto unique = ops::Unique(s.WithOpName("unique"), indices,
+                              /*attrs=*/{unique_idx_type});
+    Output ids = unique.y;
+    Output idx = unique.idx;
+    Output gathered_rows =
+        ops::Gather(s.WithOpName("gathered_rows"), embeddings, ids);
+    Output result = ops::SparseSegmentSum(s.WithOpName("result"), gathered_rows,
+                                          idx, segment_ids);
+    Output id = ops::Identity(s.WithOpName("id"), result);
+    // Evaluate the unoptimized graph to obtain the reference output.
+    GrapplerItem item;
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+    item.fetch = {"id"};
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+    ASSERT_EQ(tensors_expected.size(), 1);
+    // Optimize with only the SimplifyEmbeddingLookup stage enabled.
+    GraphDef output;
+    ArithmeticOptimizer optimizer;
+    EnableOnlySimplifyEmbeddingLookup(&optimizer);
+    OptimizeAndPrune(&optimizer, &item, &output);
+    // "result" must be rewired, and no Unique/Gather node may remain.
+    for (const auto& node : output.node()) {
+      if (node.name() == "result") {
+        EXPECT_EQ(node.input(0), "embeddings");
+        EXPECT_EQ(node.input(1), "indices");
+      }
+      EXPECT_NE(node.op(), "Unique");
+      EXPECT_NE(node.op(), "Gather");
+    }
+    // The optimized graph must produce the same values as the original.
+    auto tensors = EvaluateNodes(output, item.fetch);
+    ASSERT_EQ(tensors.size(), 1);
+    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
+  }
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
index 73bb5a0d97c..69b528fb446 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
@@ -223,6 +223,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     optimizer->options_.remove_stack_slice_same_axis = true;
   }
 
+  void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);  // Turn every stage off first.
+    optimizer->options_.simplify_embedding_lookup = true;  // Re-enable one.
+  }
+
  private:
   void DisableAllStages(ArithmeticOptimizer* optimizer) {
     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
@@ -250,6 +255,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     options.replace_mul_with_square = false;
     options.simplify_aggregation = false;
     options.unary_ops_composition = false;
+    options.simplify_embedding_lookup = false;
     optimizer->options_ = options;
   }
 };
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index b9206f40ba0..21def9cfa2c 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -4416,7 +4416,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
 
     with _initialized_session():
       with self.assertRaisesRegexp(errors.OpError,
-                                   r'indices\[0\] = 2 is not in \[0, 2\)'):
+                                   r'indices\[0\] .* 2 .* \[0, 2\)'):
         self.evaluate(embedding_lookup)
 
   @test_util.run_deprecated_v1