Added GatherNd and StridedSlice to the estimator.

PiperOrigin-RevId: 325066413
Change-Id: I15ea90b8ad01d127d028756d7267193594bd15ac
This commit is contained in:
A. Unique TensorFlower 2020-08-05 11:49:51 -07:00 committed by TensorFlower Gardener
parent cf77a7186a
commit ef0f08f5dd
2 changed files with 56 additions and 20 deletions

View File

@ -73,6 +73,7 @@ constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient"; constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient"; constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather"; constexpr char kGather[] = "Gather";
constexpr char kGatherNd[] = "GatherNd";
constexpr char kGatherV2[] = "GatherV2"; constexpr char kGatherV2[] = "GatherV2";
constexpr char kScatterAdd[] = "ScatterAdd"; constexpr char kScatterAdd[] = "ScatterAdd";
constexpr char kScatterDiv[] = "ScatterDiv"; constexpr char kScatterDiv[] = "ScatterDiv";
@ -82,6 +83,7 @@ constexpr char kScatterMul[] = "ScatterMul";
constexpr char kScatterSub[] = "ScatterSub"; constexpr char kScatterSub[] = "ScatterSub";
constexpr char kScatterUpdate[] = "ScatterUpdate"; constexpr char kScatterUpdate[] = "ScatterUpdate";
constexpr char kSlice[] = "Slice"; constexpr char kSlice[] = "Slice";
constexpr char kStridedSlice[] = "StridedSlice";
constexpr char kSpaceToDepth[] = "SpaceToDepth"; constexpr char kSpaceToDepth[] = "SpaceToDepth";
constexpr char kTranspose[] = "Transpose"; constexpr char kTranspose[] = "Transpose";
constexpr char kMaxPool[] = "MaxPool"; constexpr char kMaxPool[] = "MaxPool";
@ -402,6 +404,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
device_cost_impl_.emplace(kGather, device_cost_impl_.emplace(kGather,
wrap(&OpLevelCostEstimator::PredictGatherOrSlice)); wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
device_cost_impl_.emplace(kGatherNd,
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
device_cost_impl_.emplace(kGatherV2, device_cost_impl_.emplace(kGatherV2,
wrap(&OpLevelCostEstimator::PredictGatherOrSlice)); wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
device_cost_impl_.emplace(kScatterAdd, device_cost_impl_.emplace(kScatterAdd,
@ -421,6 +425,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
device_cost_impl_.emplace(kSlice, device_cost_impl_.emplace(kSlice,
wrap(&OpLevelCostEstimator::PredictGatherOrSlice)); wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
device_cost_impl_.emplace(kStridedSlice,
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
device_cost_impl_.emplace(kPlaceholder, device_cost_impl_.emplace(kPlaceholder,
wrap(&OpLevelCostEstimator::PredictIdentity)); wrap(&OpLevelCostEstimator::PredictIdentity));
@ -1799,15 +1805,20 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
const double output_size = CalculateOutputSize(op_info, &unknown_shapes); const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
double input_size = output_size; double input_size = output_size;
int begin_input_index = 1, end_input_index;
if (op_info.op() == "Slice") { if (op_info.op() == "Slice") {
// Add 'begin' & 'size' tensors sizes. // Slice: 'input' (omitted), 'begin', 'size'
input_size += end_input_index = 3;
CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) + } else if (op_info.op() == "StridedSlice") {
CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes); // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
end_input_index = 4;
} else { } else {
// Assuming this is "Gather" or "GatherV2" op, add 'indices' size. // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
end_input_index = 2;
}
for (int i = begin_input_index; i < end_input_index; ++i) {
input_size += input_size +=
CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes); CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes);
} }
Costs costs = Costs costs =

View File

@ -641,22 +641,26 @@ TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
} }
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) { TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
OpContext op_context; std::vector<std::string> gather_ops = {"Gather", "GatherNd", "GatherV2"};
SetCpuDevice(&op_context.op_info);
op_context.op_info.set_op("Gather");
// Huge first input shouldn't affect Gather execution and memory costs. for (const auto& op : gather_ops) {
DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info); OpContext op_context;
DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info); SetCpuDevice(&op_context.op_info);
DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info); op_context.op_info.set_op(op);
auto cost = estimator_.PredictCosts(op_context); // Huge first input shouldn't affect Gather execution and memory costs.
EXPECT_EQ(Costs::Duration(130), cost.memory_time); DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
EXPECT_EQ(Costs::Duration(16), cost.compute_time); DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
EXPECT_EQ(Costs::Duration(146), cost.execution_time); DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_FALSE(cost.inaccurate); auto cost = estimator_.PredictCosts(op_context);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); EXPECT_EQ(Costs::Duration(130), cost.memory_time);
EXPECT_EQ(Costs::Duration(16), cost.compute_time);
EXPECT_EQ(Costs::Duration(146), cost.execution_time);
EXPECT_EQ(1, cost.num_ops_total);
EXPECT_FALSE(cost.inaccurate);
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
}
} }
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) { TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
@ -697,6 +701,27 @@ TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
} }
// Checks the costs predicted for a "StridedSlice" op placed on the CPU device.
TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
  OpContext ctx;
  auto* info = &ctx.op_info;
  SetCpuDevice(info);
  info->set_op("StridedSlice");

  // The first input is deliberately huge: it shouldn't affect StridedSlice
  // execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, info);
  // Three identical two-element int64 inputs follow the data tensor.
  for (int k = 0; k < 3; ++k) {
    DescribeArbitraryRankInput({2}, DT_INT64, info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, info);

  const auto predicted = estimator_.PredictCosts(ctx);

  // Timing expectations; the execution time equals memory + compute here.
  EXPECT_EQ(Costs::Duration(81), predicted.memory_time);
  EXPECT_EQ(Costs::Duration(10), predicted.compute_time);
  EXPECT_EQ(Costs::Duration(91), predicted.execution_time);

  // Exactly one op is counted, and every shape was known, so the estimate is
  // not flagged as inaccurate.
  EXPECT_EQ(1, predicted.num_ops_total);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(0, predicted.num_ops_with_unknown_shapes);
}
TEST_F(OpLevelCostEstimatorTest, TestScatterOps) { TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax", std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax",
"ScatterMin", "ScatterMul", "ScatterSub", "ScatterMin", "ScatterMul", "ScatterSub",