Added GatherNd and StridedSlice to the estimator.
PiperOrigin-RevId: 325066413 Change-Id: I15ea90b8ad01d127d028756d7267193594bd15ac
This commit is contained in:
parent
cf77a7186a
commit
ef0f08f5dd
@ -73,6 +73,7 @@ constexpr char kSize[] = "Size";
|
||||
constexpr char kStopGradient[] = "StopGradient";
|
||||
constexpr char kPreventGradient[] = "PreventGradient";
|
||||
constexpr char kGather[] = "Gather";
|
||||
constexpr char kGatherNd[] = "GatherNd";
|
||||
constexpr char kGatherV2[] = "GatherV2";
|
||||
constexpr char kScatterAdd[] = "ScatterAdd";
|
||||
constexpr char kScatterDiv[] = "ScatterDiv";
|
||||
@ -82,6 +83,7 @@ constexpr char kScatterMul[] = "ScatterMul";
|
||||
constexpr char kScatterSub[] = "ScatterSub";
|
||||
constexpr char kScatterUpdate[] = "ScatterUpdate";
|
||||
constexpr char kSlice[] = "Slice";
|
||||
constexpr char kStridedSlice[] = "StridedSlice";
|
||||
constexpr char kSpaceToDepth[] = "SpaceToDepth";
|
||||
constexpr char kTranspose[] = "Transpose";
|
||||
constexpr char kMaxPool[] = "MaxPool";
|
||||
@ -402,6 +404,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
|
||||
device_cost_impl_.emplace(kGather,
|
||||
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
|
||||
device_cost_impl_.emplace(kGatherNd,
|
||||
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
|
||||
device_cost_impl_.emplace(kGatherV2,
|
||||
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
|
||||
device_cost_impl_.emplace(kScatterAdd,
|
||||
@ -421,6 +425,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
|
||||
device_cost_impl_.emplace(kSlice,
|
||||
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
|
||||
device_cost_impl_.emplace(kStridedSlice,
|
||||
wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
|
||||
|
||||
device_cost_impl_.emplace(kPlaceholder,
|
||||
wrap(&OpLevelCostEstimator::PredictIdentity));
|
||||
@ -1799,15 +1805,20 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
|
||||
|
||||
const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
|
||||
double input_size = output_size;
|
||||
int begin_input_index = 1, end_input_index;
|
||||
if (op_info.op() == "Slice") {
|
||||
// Add 'begin' & 'size' tensors sizes.
|
||||
input_size +=
|
||||
CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes) +
|
||||
CalculateTensorElementCount(op_info.inputs(2), &unknown_shapes);
|
||||
// Slice: 'input' (omitted), 'begin', 'size'
|
||||
end_input_index = 3;
|
||||
} else if (op_info.op() == "StridedSlice") {
|
||||
// StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
|
||||
end_input_index = 4;
|
||||
} else {
|
||||
// Assuming this is "Gather" or "GatherV2" op, add 'indices' size.
|
||||
// Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
|
||||
end_input_index = 2;
|
||||
}
|
||||
for (int i = begin_input_index; i < end_input_index; ++i) {
|
||||
input_size +=
|
||||
CalculateTensorElementCount(op_info.inputs(1), &unknown_shapes);
|
||||
CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes);
|
||||
}
|
||||
|
||||
Costs costs =
|
||||
|
@ -641,22 +641,26 @@ TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
op_context.op_info.set_op("Gather");
|
||||
std::vector<std::string> gather_ops = {"Gather", "GatherNd", "GatherV2"};
|
||||
|
||||
// Huge first input shouldn't affect Gather execution and memory costs.
|
||||
DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
|
||||
for (const auto& op : gather_ops) {
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
op_context.op_info.set_op(op);
|
||||
|
||||
auto cost = estimator_.PredictCosts(op_context);
|
||||
EXPECT_EQ(Costs::Duration(130), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(16), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(146), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
// Huge first input shouldn't affect Gather execution and memory costs.
|
||||
DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);
|
||||
|
||||
auto cost = estimator_.PredictCosts(op_context);
|
||||
EXPECT_EQ(Costs::Duration(130), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(16), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(146), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
|
||||
@ -697,6 +701,27 @@ TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
op_context.op_info.set_op("StridedSlice");
|
||||
|
||||
// Huge first input shouldn't affect StridedSlice execution and memory costs.
|
||||
DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
|
||||
DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);
|
||||
|
||||
auto cost = estimator_.PredictCosts(op_context);
|
||||
EXPECT_EQ(Costs::Duration(81), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(10), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(91), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
|
||||
std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax",
|
||||
"ScatterMin", "ScatterMul", "ScatterSub",
|
||||
|
Loading…
x
Reference in New Issue
Block a user