From 3b336d3173d5c8e84531bd97cc0e452ac9a56a7e Mon Sep 17 00:00:00 2001 From: Derek Murray <mrry@google.com> Date: Fri, 24 Apr 2020 15:30:26 -0700 Subject: [PATCH] [Grappler] Add arithmetic optimizer stage for optimizing `tf.nn.embedding_lookup_sparse()`. This optimization eliminates unnecessary `tf.unique()` and `tf.gather()` operations from `tf.nn.embedding_lookup_sparse()` when the embeddings are unpartitioned (e.g. at inference time) and weights are not used. Instead, the `tf.sparse.segment_<combiner>()` operation is applied directly to the embedding matrix without uniquifying the IDs. PiperOrigin-RevId: 308338718 Change-Id: I6e689610a8f4f3dd0a3e8af77cce609ac6d4f9f9 --- tensorflow/core/grappler/op_types.cc | 19 ++++ tensorflow/core/grappler/op_types.h | 3 + .../optimizers/arithmetic_optimizer.cc | 104 ++++++++++++++++++ .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 44 ++++++++ .../arithmetic_optimizer_test_utils.h | 6 + .../feature_column/feature_column_test.py | 2 +- 7 files changed, 178 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 8b8df527041..9e3b401154a 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -76,6 +76,15 @@ bool IsAnyMin(const NodeDef& node) { return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin"; } +bool IsAnySparseSegmentReduction(const NodeDef& node) { + const auto& op = node.op(); + return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" || + op == "SparseSegmentMean" || + op == "SparseSegmentMeanWithNumSegments" || + op == "SparseSegmentSqrtN" || + op == "SparseSegmentSqrtNWithNumSegments"; +} + bool IsApproximateEqual(const NodeDef& node) { return node.op() == "ApproximateEqual"; } @@ -268,6 +277,11 @@ bool IsFusedBatchNormGrad(const NodeDef& node) { op == "FusedBatchNormGradV3"; } +bool IsGather(const NodeDef& node) { + const auto& op = node.op(); + return op == "Gather" || op == "GatherV2"; +} + bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; } bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; } @@ -589,6 +603,11 @@ bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; } bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; } +bool IsUnique(const NodeDef& node) { + const auto& op = node.op(); + return op == "Unique" || op == "UniqueV2"; +} + bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; } bool IsVariable(const NodeDef& node) { diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 88d81c5b202..b1624ac70c6 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -34,6 +34,7 @@ bool IsAnyMax(const NodeDef& node); bool IsAnyMaxPool(const NodeDef& node); bool IsAnyMin(const NodeDef& node); bool IsAnyMul(const NodeDef& node); +bool IsAnySparseSegmentReduction(const NodeDef& node); bool IsApproximateEqual(const NodeDef& node); bool IsArg(const NodeDef& node); bool IsArgMax(const NodeDef& node); @@ -81,6 +82,7 @@ bool IsFloorMod(const NodeDef& node); bool IsFusedBatchNorm(const NodeDef& node); bool IsFusedBatchNormEx(const NodeDef& node); bool IsFusedBatchNormGrad(const NodeDef& node); +bool IsGather(const NodeDef& node); bool IsGreater(const NodeDef& node); bool IsGreaterEqual(const NodeDef& node); bool IsHistogramSummary(const NodeDef& node); @@ -187,6 +189,7 @@ bool IsTile(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsTruncateDiv(const NodeDef& node); bool IsTruncateMod(const NodeDef& node); +bool IsUnique(const NodeDef& node); bool IsUnpack(const NodeDef& node); bool IsVariable(const NodeDef& node); bool IsWhile(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 200572456c3..b9502ffb45e 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -3466,6 +3466,108 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { } }; +// Eliminates unnecessary copies during sparse embedding lookup operations. +// +// For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function +// generates code of the form: +// +// embeddings = <a 2D Tensor> +// sparse_ids = <a tf.int64 SparseTensor> +// segment_ids = sparse_ids.indices[:, 0] +// ids, idx = tf.unique(sparse_ids.values) +// gathered_rows = tf.gather(params, ids) +// result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids) +// +// In this case, all of the work in `tf.unique()` and `tf.gather()` +// can be avoided by passing the full embeddings to +// `tf.sparse.segment_<combiner>()` and performing the same amount of +// computation (but fewer copies and allocations) as follows: +// +// embeddings = <a 2D Tensor> +// sparse_ids = <a tf.int64 SparseTensor> +// segment_ids = sparse_ids.indices[:, 0] +// result = tf.sparse.segment_<combiner>( +// embeddings, sparse_ids.values, segment_ids) +class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { + public: + explicit SimplifyEmbeddingLookupStage( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) { + } + ~SimplifyEmbeddingLookupStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsAnySparseSegmentReduction(*node); + } + + Status TrySimplify(NodeDef* reduction_node, + string* simplified_node_name) override { + if (IsInPreserveSet(*reduction_node)) return Status::OK(); + + // Input 0 (data) of the reduction node must be a tf.gather() on the 0th + // axis. + NodeDef* gather_node = nullptr; + TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node)); + if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) || + gather_node->device() != reduction_node->device()) + return Status::OK(); + if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2)) + return Status::OK(); + + // Input 1 (indices) of the gather node must be a tf.unique() on the 0th + // axis. + NodeDef* unique_node = nullptr; + TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node)); + if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) || + unique_node->device() != gather_node->device()) + return Status::OK(); + if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1)) + return Status::OK(); + + DataType unique_element_type; + TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type)); + + // Input 1 (indices) of the reduction node must be output 1 of the unique + // node. + const TensorId idx_tensor = ParseTensorName(reduction_node->input(1)); + if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK(); + + // Input 0 (data) of the reduction node becomes input 1 (params) of the + // gather node. + reduction_node->set_input(0, gather_node->input(0)); + ctx().node_map->UpdateInput(reduction_node->name(), + reduction_node->input(0), + gather_node->input(0)); + + // Input 1 (indices) of the reduction node becomes input 0 (x) of the unique + // node. + reduction_node->set_input(1, unique_node->input(0)); + ctx().node_map->UpdateInput(reduction_node->name(), + reduction_node->input(1), + unique_node->input(0)); + SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node); + + *simplified_node_name = reduction_node->name(); + return Status::OK(); + } + + private: + bool IsAxis0(const NodeDef& node, int axis_input) { + Tensor axis_tensor; + if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor)) + return false; + if (axis_tensor.NumElements() != 1) return false; + if (axis_tensor.dtype() == DT_INT32) { + return axis_tensor.flat<int32>()(0) == 0; + } else if (axis_tensor.dtype() == DT_INT64) { + return axis_tensor.flat<int64>()(0) == 0; + } else { + return false; + } + } +}; + } // namespace Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { @@ -3538,6 +3640,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext); if (options_.fuse_squared_diff) pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext); + if (options_.simplify_embedding_lookup) + pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << absl::StrJoin(pipeline.StageNames(), ", "); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 76aca8b840e..aa4762ff5c3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -85,6 +85,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool convert_expm1 = true; bool unary_ops_composition = true; bool remove_stack_slice_same_axis = true; + bool simplify_embedding_lookup = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 50896b11923..8da306190a0 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -4085,5 +4085,49 @@ TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) { test::ExpectTensorEqual<bfloat16>(tensors[0], tensors_expected[0]); } +TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) { + for (DataType unique_idx_type : {DT_INT32, DT_INT64}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output embeddings = ops::Const(s.WithOpName("embeddings"), + {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + Output segment_ids = + ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}); + Output indices = ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}); + auto unique = ops::Unique(s.WithOpName("unique"), indices, + /*attrs=*/{unique_idx_type}); + Output ids = unique.y; + Output idx = unique.idx; + Output gathered_rows = + ops::Gather(s.WithOpName("gathered_rows"), embeddings, ids); + Output result = ops::SparseSegmentSum(s.WithOpName("result"), gathered_rows, + idx, segment_ids); + Output id = ops::Identity(s.WithOpName("id"), result); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"id"}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + ASSERT_EQ(tensors_expected.size(), 1); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlySimplifyEmbeddingLookup(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + for (const auto& node : output.node()) { + if (node.name() == "result") { + EXPECT_EQ(node.input(0), "embeddings"); + EXPECT_EQ(node.input(1), "indices"); + } + EXPECT_NE(node.op(), "Unique"); + EXPECT_NE(node.op(), "Gather"); + } + + auto tensors = EvaluateNodes(output, item.fetch); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]); + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index 73bb5a0d97c..69b528fb446 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -223,6 +223,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_stack_slice_same_axis = true; } + void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.simplify_embedding_lookup = true; + } + private: void DisableAllStages(ArithmeticOptimizer* optimizer) { ArithmeticOptimizer::ArithmeticOptimizerOptions options; @@ -250,6 +255,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.replace_mul_with_square = false; options.simplify_aggregation = false; options.unary_ops_composition = false; + options.simplify_embedding_lookup = false; optimizer->options_ = options; } }; diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index b9206f40ba0..21def9cfa2c 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -4416,7 +4416,7 @@ class IdentityCategoricalColumnTest(test.TestCase): with _initialized_session(): with self.assertRaisesRegexp(errors.OpError, - r'indices\[0\] = 2 is not in \[0, 2\)'): + r'indices\[0\] .* 2 .* \[0, 2\)'): self.evaluate(embedding_lookup) @test_util.run_deprecated_v1