Add GatherNd and StridedSlice to the grappler op-level cost estimator.

What the patch does:
  * Declares the op-name constants kGatherNd and kStridedSlice.
  * Registers both ops in the OpLevelCostEstimator constructor with the
    existing PredictGatherOrSlice predictor.
  * In PredictGatherOrSlice, replaces the per-op input-size expressions with
    a loop that adds CalculateTensorElementCount(op_info.inputs(i)) for i in
    [begin_input_index, end_input_index).  Input 0 is always skipped;
    end_input_index is 3 for Slice, 4 for StridedSlice and 2 otherwise.
  * Tests: TestGatherCosts now iterates over Gather, GatherNd and GatherV2
    with the same expected costs; a new TestStridedSliceCosts expects
    memory_time 81, compute_time 10 and execution_time 91.

Notes on this copy of the patch:
  * The line structure (newlines, indentation, blank context lines) was
    reconstructed from a whitespace-collapsed copy.  Hunk line counts were
    re-checked against every @@ header, but whitespace inside context lines
    may differ from the tree; `git apply --ignore-whitespace` may be needed.
  * The template arguments of the two std::vector declarations in the test
    file were missing in the collapsed copy and are restored as <string>.
    TODO confirm against the target tree.

Review notes (not applied, so that the patch content stays as recorded):
  * end_input_index is declared without an initializer in a combined
    declaration; it is assigned on every branch of the if/else chain, but
    one initialized declaration per line would be more robust.
  * The op name is compared against the literals "Slice" and "StridedSlice"
    although kSlice and kStridedSlice are declared in the same file.
  * No check that op_info has at least end_input_index inputs is visible in
    this hunk before op_info.inputs(i) is read; verify that one exists
    earlier in the function.

diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index ed86e92a2e7..d76ff4359c1 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -73,6 +73,7 @@ constexpr char kSize[] = "Size";
 constexpr char kStopGradient[] = "StopGradient";
 constexpr char kPreventGradient[] = "PreventGradient";
 constexpr char kGather[] = "Gather";
+constexpr char kGatherNd[] = "GatherNd";
 constexpr char kGatherV2[] = "GatherV2";
 constexpr char kScatterAdd[] = "ScatterAdd";
 constexpr char kScatterDiv[] = "ScatterDiv";
@@ -82,6 +83,7 @@ constexpr char kScatterMul[] = "ScatterMul";
 constexpr char kScatterSub[] = "ScatterSub";
 constexpr char kScatterUpdate[] = "ScatterUpdate";
 constexpr char kSlice[] = "Slice";
+constexpr char kStridedSlice[] = "StridedSlice";
 constexpr char kSpaceToDepth[] = "SpaceToDepth";
 constexpr char kTranspose[] = "Transpose";
 constexpr char kMaxPool[] = "MaxPool";
@@ -402,6 +404,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
 
   device_cost_impl_.emplace(kGather,
                             wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
+  device_cost_impl_.emplace(kGatherNd,
+                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
   device_cost_impl_.emplace(kGatherV2,
                             wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
   device_cost_impl_.emplace(kScatterAdd,
@@ -421,6 +425,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
 
   device_cost_impl_.emplace(kSlice,
                             wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
+  device_cost_impl_.emplace(kStridedSlice,
+                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
 
   device_cost_impl_.emplace(kPlaceholder,
                             wrap(&OpLevelCostEstimator::PredictIdentity));
@@ -1799,15 +1805,20 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
 
   const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
   double input_size = output_size;
+  int begin_input_index = 1, end_input_index;
   if (op_info.op() == "Slice") {
-    // Add 'begin' & 'size' tensors sizes.
-    input_size +=
-        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
-        CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
+    // Slice: 'input' (omitted), 'begin', 'size'
+    end_input_index = 3;
+  } else if (op_info.op() == "StridedSlice") {
+    // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
+    end_input_index = 4;
   } else {
-    // Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
+    // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
+    end_input_index = 2;
+  }
+  for (int i = begin_input_index; i < end_input_index; ++i) {
     input_size +=
-        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
+        CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes);
   }
 
   Costs costs =
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 90f3e969df9..c5209753a90 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -641,22 +641,26 @@ TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
 }
 
 TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
-  OpContext op_context;
-  SetCpuDevice(&op_context.op_info);
-  op_context.op_info.set_op("Gather");
+  std::vector<string> gather_ops = {"Gather", "GatherNd", "GatherV2"};
 
-  // Huge first input shouldn't affect Gather execution and memory costs.
-  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
-  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
-  DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+  for (const auto& op : gather_ops) {
+    OpContext op_context;
+    SetCpuDevice(&op_context.op_info);
+    op_context.op_info.set_op(op);
 
-  auto cost = estimator_.PredictCosts(op_context);
-  EXPECT_EQ(Costs::Duration(130), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(16), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(146), cost.execution_time);
-  EXPECT_EQ(1, cost.num_ops_total);
-  EXPECT_FALSE(cost.inaccurate);
-  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+    // Huge first input shouldn't affect Gather execution and memory costs.
+    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+    DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
+    DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+
+    auto cost = estimator_.PredictCosts(op_context);
+    EXPECT_EQ(Costs::Duration(130), cost.memory_time);
+    EXPECT_EQ(Costs::Duration(16), cost.compute_time);
+    EXPECT_EQ(Costs::Duration(146), cost.execution_time);
+    EXPECT_EQ(1, cost.num_ops_total);
+    EXPECT_FALSE(cost.inaccurate);
+    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+  }
 }
 
 TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
@@ -697,6 +701,27 @@ TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
   EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
 }
 
+TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
+  OpContext op_context;
+  SetCpuDevice(&op_context.op_info);
+  op_context.op_info.set_op("StridedSlice");
+
+  // Huge first input shouldn't affect StridedSlice execution and memory costs.
+  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+
+  auto cost = estimator_.PredictCosts(op_context);
+  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
+  EXPECT_EQ(1, cost.num_ops_total);
+  EXPECT_FALSE(cost.inaccurate);
+  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+}
+
 TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
   std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax",
                                      "ScatterMin", "ScatterMul", "ScatterSub",