From f722aee786f51962f07eaf02653ff509a1647b1c Mon Sep 17 00:00:00 2001 From: Doe Hyun Yoon Date: Tue, 5 Feb 2019 10:21:32 -0800 Subject: [PATCH] Return minimum costs for all persistent ops: Const and Vars. Currently, PredictCosts() returns minimum costs for Const, Var, and VarV2, but not other var ops. Also, the logic was incorrect in that it sets minimum cost to compute time, but zero to execution time. This CL fixes these. PiperOrigin-RevId: 232510666 --- .../costs/analytical_cost_estimator_test.cc | 2 +- .../grappler/costs/op_level_cost_estimator.cc | 51 +++++++++++-------- .../grappler/costs/op_level_cost_estimator.h | 1 + .../costs/op_level_cost_estimator_test.cc | 20 ++++++++ 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc index eb7ee8dc0a1..2c319b6c6b0 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { Costs summary; TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary)); - EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time); + EXPECT_EQ(Costs::NanoSeconds(9157), summary.execution_time); // Note there are totally 17 nodes (RandomUniform creates 2 nodes), but // grappler will not process "label", therefore we have 15 here instead EXPECT_EQ(15, summary.num_ops_total); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 59d20f1fb9a..1e2e160955c 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -27,7 +27,6 @@ namespace tensorflow { namespace grappler { constexpr int kOpsPerMac = 2; -constexpr char kConst[] = "Const"; constexpr char kGuaranteeConst[] = "GuaranteeConst"; constexpr char kConv2d[] = "Conv2D"; constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter"; @@ -50,8 +49,6 @@ constexpr char kSqueeze[] = "Squeeze"; constexpr char kRecv[] = "_Recv"; constexpr char kSend[] = "_Send"; constexpr char kBatchMatMul[] = "BatchMatMul"; -constexpr char kVariable[] = "Variable"; -constexpr char kVariableV2[] = "VariableV2"; constexpr char kRank[] = "Rank"; constexpr char kShape[] = "Shape"; constexpr char kShapeN[] = "ShapeN"; @@ -68,6 +65,13 @@ constexpr char kAvgPoolGrad[] = "AvgPoolGrad"; constexpr char kFusedBatchNorm[] = "FusedBatchNorm"; constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad"; constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2"; +// Persistent ops. +constexpr char kConst[] = "Const"; +constexpr char kVariable[] = "Variable"; +constexpr char kVariableV2[] = "VariableV2"; +constexpr char kAutoReloadVariable[] = "AutoReloadVariable"; +constexpr char kVarHandleOp[] = "VarHandleOp"; +constexpr char kReadVariableOp[] = "ReadVariableOp"; static const Costs::Duration kMinComputeTime(1); @@ -259,10 +263,6 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)}, {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)}, - {kConst, wrap(&OpLevelCostEstimator::PredictVariable)}, - {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)}, - {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)}, - {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)}, {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)}, {kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)}, @@ -276,6 +276,11 @@ OpLevelCostEstimator::OpLevelCostEstimator() { wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)}, }; + persistent_ops_ = { + kConst, kVariable, kVariableV2, kAutoReloadVariable, + kVarHandleOp, kReadVariableOp, + }; + #define EIGEN_COST(X) Eigen::internal::functor_traits::Cost // Quantize = apply min and max bounds, multiply by scale factor and round. @@ -363,21 +368,25 @@ OpLevelCostEstimator::OpLevelCostEstimator() { Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { const auto& op_info = op_context.op_info; auto it = device_cost_impl_.find(op_info.op()); - if (it == device_cost_impl_.end()) { - if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) { - return PredictCwiseOp(op_context); - } - - VLOG(1) << "Missing accurate estimator for op: " << op_info.op(); - - return PredictCostOfAnUnknownOp(op_context); + if (it != device_cost_impl_.end()) { + std::function estimator = it->second; + Costs costs = estimator(op_context); + VLOG(1) << "Operation " << op_info.op() << " takes " + << costs.execution_time.count() << " ns."; + return costs; } - std::function estimator = it->second; - Costs costs = estimator(op_context); - VLOG(1) << "Operation " << op_info.op() << " takes " - << costs.execution_time.count() << " ns."; - return costs; + if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) { + return PredictVariable(op_context); + } + + if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) { + return PredictCwiseOp(op_context); + } + + VLOG(1) << "Missing accurate estimator for op: " << op_info.op(); + + return PredictCostOfAnUnknownOp(op_context); } DeviceInfo OpLevelCostEstimator::GetDeviceInfo( @@ -1240,7 +1249,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const { result.num_ops_with_unknown_shapes = result.inaccurate; result.compute_time = kMinComputeTime; - result.execution_time = result.execution_time; + result.execution_time = result.compute_time; return result; } diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index f8ba8c6637d..ace8fb218c7 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -193,6 +193,7 @@ class OpLevelCostEstimator { // If true, assume compute and memory overlap; hence, the op cost is max of // compute_time and memory_time, insteaf of sum of those two. bool compute_memory_overlap_; + std::set persistent_ops_; private: friend class OpLevelCostEstimatorTest; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index aa0fc9d6c2a..04c6ada2bf6 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -499,6 +499,26 @@ class OpLevelCostEstimatorTest : public ::testing::Test { OpLevelCostEstimator estimator_; }; +TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + std::unordered_set persisent_ops = { + "Const", "Variable", "VariableV2", "AutoReloadVariable", + "VarHandleOp", "ReadVariableOp", + }; + // Minmum cost for all persistent ops. + for (const auto& op : persisent_ops) { + op_context.op_info.set_op(op); + auto cost = estimator_.PredictCosts(op_context); + EXPECT_EQ(Costs::Duration(0), cost.memory_time); + EXPECT_EQ(Costs::Duration(1), cost.compute_time); + EXPECT_EQ(Costs::Duration(1), cost.execution_time); + EXPECT_EQ(1, cost.num_ops_total); + EXPECT_FALSE(cost.inaccurate); + EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); + } +} + TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) { OpContext op_context; SetCpuDevice(&op_context.op_info);