diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 55f84601a59..8be53aa08e3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -3571,6 +3571,65 @@ class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage { } }; +// Eliminates unnecessary casts before sparse segment reduction operations. +// +// Existing graphs and library code would often insert a cast from DT_INT64 to +// DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops. +// Support for for DT_INT64 indices and/or segment_ids now exists, so we can +// pass the input directly without a cast. +class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage { + public: + explicit RemoveCastIntoSegmentReductionStage( + const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx, + ctx_ext) {} + ~RemoveCastIntoSegmentReductionStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsAnySparseSegmentReduction(*node); + } + + Status TrySimplify(NodeDef* reduction_node, + string* simplified_node_name) override { + if (IsInPreserveSet(*reduction_node)) return Status::OK(); + + bool optimized = false; + + // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or + // DT_INT64. + std::array, 2> input_details = { + std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")}; + + for (const auto& input : input_details) { + int input_index = input.first; + const string& type_attr_name = input.second; + NodeDef* cast_node = nullptr; + TF_RETURN_IF_ERROR( + GetInputNode(reduction_node->input(input_index), &cast_node)); + DataType original_index_type; + if (IsCastFromSupportedType(*cast_node, &original_index_type)) { + reduction_node->set_input(input_index, cast_node->input(0)); + ctx().node_map->UpdateInput(reduction_node->name(), + reduction_node->input(1), + cast_node->input(0)); + SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node); + optimized = true; + } + } + + if (optimized) *simplified_node_name = reduction_node->name(); + return Status::OK(); + } + + private: + bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) { + if (!IsCast(node)) return false; + if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false; + return *out_input_type == DT_INT32 || *out_input_type == DT_INT64; + } +}; + } // namespace Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { @@ -3645,6 +3704,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.simplify_embedding_lookup) pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_cast_into_segment_reduction) + pipeline.AddStage(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << absl::StrJoin(pipeline.StageNames(), ", "); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index aa4762ff5c3..044dc855244 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -86,6 +86,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool unary_ops_composition = true; bool remove_stack_slice_same_axis = true; bool simplify_embedding_lookup = true; + bool remove_cast_into_segment_reduction = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 8da306190a0..477c284d44c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -4129,5 +4129,49 @@ TEST_F(ArithmeticOptimizerTest, SimplifyEmbeddingLookup) { } } +TEST_F(ArithmeticOptimizerTest, RemoveCastIntoSegmentReduction) { + for (DataType indices_type : {DT_INT32, DT_INT64}) { + for (DataType segment_ids_type : {DT_INT32, DT_INT64}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output embeddings = ops::Const(s.WithOpName("embeddings"), + {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + Output indices = + ops::Cast(s.WithOpName("cast_indices"), + ops::Const(s.WithOpName("indices"), {0, 0, 1, 0, 1, 0, 1}), + indices_type); + Output segment_ids = ops::Cast( + s.WithOpName("cast_segment_ids"), + ops::Const(s.WithOpName("segment_ids"), {0, 1, 1, 2, 2, 2, 2}), + segment_ids_type); + Output result = ops::SparseSegmentSum(s.WithOpName("result"), embeddings, + indices, segment_ids); + Output id = ops::Identity(s.WithOpName("id"), result); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"id"}; + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + ASSERT_EQ(tensors_expected.size(), 1); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveCastIntoSegmentReduction(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + for (const auto& node : output.node()) { + if (node.name() == "result") { + EXPECT_EQ(node.input(1), "indices"); + EXPECT_EQ(node.input(2), "segment_ids"); + } + EXPECT_NE(node.op(), "Cast"); + } + + auto tensors = EvaluateNodes(output, item.fetch); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorEqual(tensors[0], tensors_expected[0]); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index 69b528fb446..9025635e668 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -228,6 +228,12 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.simplify_embedding_lookup = true; } + void EnableOnlyRemoveCastIntoSegmentReduction( + ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_cast_into_segment_reduction = true; + } + private: void DisableAllStages(ArithmeticOptimizer* optimizer) { ArithmeticOptimizer::ArithmeticOptimizerOptions options; @@ -256,6 +262,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { options.simplify_aggregation = false; options.unary_ops_composition = false; options.simplify_embedding_lookup = false; + options.remove_cast_into_segment_reduction = false; optimizer->options_ = options; } };