Added GatherNd and StridedSlice to the estimator.

PiperOrigin-RevId: 325066413
Change-Id: I15ea90b8ad01d127d028756d7267193594bd15ac
A. Unique TensorFlower 2020-08-05 11:49:51 -07:00 committed by TensorFlower Gardener
parent cf77a7186a
commit ef0f08f5dd
2 changed files with 56 additions and 20 deletions
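In effect, both ops are now registered in device_cost_impl_ and routed to the existing PredictGatherOrSlice handler instead of the estimator's generic handling for unregistered ops. A minimal sketch of how a caller would exercise this, mirroring the tests further down (the include path and namespace qualification are assumptions, not part of this change):

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"  // assumed path

// Sketch only: op_context.op_info must already describe the op's inputs and
// outputs, as the tests below do with DescribeArbitraryRankInput/Output.
tensorflow::grappler::Costs EstimateCost(
    const tensorflow::grappler::OpContext& op_context) {
  tensorflow::grappler::OpLevelCostEstimator estimator;
  // "StridedSlice" and "GatherNd" nodes now take the PredictGatherOrSlice path.
  return estimator.PredictCosts(op_context);
}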


@@ -73,6 +73,7 @@ constexpr char kSize[] = "Size";
 constexpr char kStopGradient[] = "StopGradient";
 constexpr char kPreventGradient[] = "PreventGradient";
 constexpr char kGather[] = "Gather";
+constexpr char kGatherNd[] = "GatherNd";
 constexpr char kGatherV2[] = "GatherV2";
 constexpr char kScatterAdd[] = "ScatterAdd";
 constexpr char kScatterDiv[] = "ScatterDiv";
@@ -82,6 +83,7 @@ constexpr char kScatterMul[] = "ScatterMul";
 constexpr char kScatterSub[] = "ScatterSub";
 constexpr char kScatterUpdate[] = "ScatterUpdate";
 constexpr char kSlice[] = "Slice";
+constexpr char kStridedSlice[] = "StridedSlice";
 constexpr char kSpaceToDepth[] = "SpaceToDepth";
 constexpr char kTranspose[] = "Transpose";
 constexpr char kMaxPool[] = "MaxPool";
@@ -402,6 +404,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
   device_cost_impl_.emplace(kGather,
                             wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
+  device_cost_impl_.emplace(kGatherNd,
+                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
   device_cost_impl_.emplace(kGatherV2,
                             wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
   device_cost_impl_.emplace(kScatterAdd,
@@ -421,6 +425,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
   device_cost_impl_.emplace(kSlice,
                             wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
+  device_cost_impl_.emplace(kStridedSlice,
+                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
   device_cost_impl_.emplace(kPlaceholder,
                             wrap(&OpLevelCostEstimator::PredictIdentity));
@@ -1799,15 +1805,20 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
   const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
   double input_size = output_size;
+  int begin_input_index = 1, end_input_index;
   if (op_info.op() == "Slice") {
-    // Add 'begin' & 'size' tensors sizes.
-    input_size +=
-        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
-        CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
+    // Slice: 'input' (omitted), 'begin', 'size'
+    end_input_index = 3;
+  } else if (op_info.op() == "StridedSlice") {
+    // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
+    end_input_index = 4;
   } else {
-    // Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
+    // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
+    end_input_index = 2;
+  }
+  for (int i = begin_input_index; i < end_input_index; ++i) {
     input_size +=
-        CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
+        CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes);
   }
   Costs costs =
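The effect of the new branch: only the output plus the small auxiliary inputs are charged as memory traffic, while the potentially huge data tensor at inputs(0) is skipped. A standalone sketch of that accounting, with hypothetical arguments standing in for CalculateOutputSize / CalculateTensorElementCount (an illustration, not the real implementation):

#include <string>
#include <vector>

// input_element_counts[i] plays the role of CalculateTensorElementCount for
// op_info.inputs(i); index 0 (the data tensor) is never read by the loop.
double SketchInputSize(const std::string& op, double output_size_bytes,
                       const std::vector<double>& input_element_counts) {
  int end_input_index = 2;                        // Gather/GatherV2/GatherNd: 'indices'
  if (op == "Slice") end_input_index = 3;         // 'begin', 'size'
  if (op == "StridedSlice") end_input_index = 4;  // 'begin', 'end', 'strides'
  double input_size = output_size_bytes;
  for (int i = 1; i < end_input_index; ++i) {
    input_size += input_element_counts[i];
  }
  return input_size;
}
// E.g. for the StridedSlice test below: a 10x10 float output (400 bytes) plus
// three 2-element index tensors gives 400 + 2 + 2 + 2 = 406.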


@@ -641,22 +641,26 @@ TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
 }
 TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
-  OpContext op_context;
-  SetCpuDevice(&op_context.op_info);
-  op_context.op_info.set_op("Gather");
+  std::vector<std::string> gather_ops = {"Gather", "GatherNd", "GatherV2"};
-  // Huge first input shouldn't affect Gather execution and memory costs.
-  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
-  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
-  DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+  for (const auto& op : gather_ops) {
+    OpContext op_context;
+    SetCpuDevice(&op_context.op_info);
+    op_context.op_info.set_op(op);
-  auto cost = estimator_.PredictCosts(op_context);
-  EXPECT_EQ(Costs::Duration(130), cost.memory_time);
-  EXPECT_EQ(Costs::Duration(16), cost.compute_time);
-  EXPECT_EQ(Costs::Duration(146), cost.execution_time);
-  EXPECT_EQ(1, cost.num_ops_total);
-  EXPECT_FALSE(cost.inaccurate);
-  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+    // Huge first input shouldn't affect Gather execution and memory costs.
+    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+    DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
+    DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
+    auto cost = estimator_.PredictCosts(op_context);
+    EXPECT_EQ(Costs::Duration(130), cost.memory_time);
+    EXPECT_EQ(Costs::Duration(16), cost.compute_time);
+    EXPECT_EQ(Costs::Duration(146), cost.execution_time);
+    EXPECT_EQ(1, cost.num_ops_total);
+    EXPECT_FALSE(cost.inaccurate);
+    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+  }
 }
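For reference, the expected durations follow from the accounting in PredictGatherOrSlice above, assuming the effective rates implied by the pre-existing Gather/Slice expectations on this CPU test device (roughly 10 bytes and 10 ops per nanosecond; an inference from the numbers, not something stated in this change):

// Gather / GatherNd / GatherV2 case:
//   output {16, 10} DT_FLOAT -> 160 elements = 640 bytes
//   input_size = 640 + 16 (indices element count) = 656
//   memory_time   ~ (656 + 640) / 10 = 129.6 -> 130
//   compute_time  ~ 160 / 10 = 16
//   execution_time = 130 + 16 = 146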
 TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
@@ -697,6 +701,27 @@ TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
   EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
 }
+TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
+  OpContext op_context;
+  SetCpuDevice(&op_context.op_info);
+  op_context.op_info.set_op("StridedSlice");
+  // Huge first input shouldn't affect StridedSlice execution and memory costs.
+  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
+  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
+  auto cost = estimator_.PredictCosts(op_context);
+  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
+  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
+  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
+  EXPECT_EQ(1, cost.num_ops_total);
+  EXPECT_FALSE(cost.inaccurate);
+  EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+}
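Under the same assumed rates, the new StridedSlice expectations also line up with the size accounting:

// StridedSlice case:
//   output {10, 10} DT_FLOAT -> 100 elements = 400 bytes
//   input_size = 400 + 2 + 2 + 2 (begin/end/strides element counts) = 406
//   memory_time   ~ (406 + 400) / 10 = 80.6 -> 81
//   compute_time  ~ 100 / 10 = 10
//   execution_time = 81 + 10 = 91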
 TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
   std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax",
                                      "ScatterMin", "ScatterMul", "ScatterSub",