diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc index eb7ee8dc0a1..2c319b6c6b0 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { Costs summary; TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary)); - EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time); + EXPECT_EQ(Costs::NanoSeconds(9157), summary.execution_time); // Note there are totally 17 nodes (RandomUniform creates 2 nodes), but // grappler will not process "label", therefore we have 15 here instead EXPECT_EQ(15, summary.num_ops_total); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 59d20f1fb9a..1e2e160955c 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -27,7 +27,6 @@ namespace tensorflow { namespace grappler { constexpr int kOpsPerMac = 2; -constexpr char kConst[] = "Const"; constexpr char kGuaranteeConst[] = "GuaranteeConst"; constexpr char kConv2d[] = "Conv2D"; constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter"; @@ -50,8 +49,6 @@ constexpr char kSqueeze[] = "Squeeze"; constexpr char kRecv[] = "_Recv"; constexpr char kSend[] = "_Send"; constexpr char kBatchMatMul[] = "BatchMatMul"; -constexpr char kVariable[] = "Variable"; -constexpr char kVariableV2[] = "VariableV2"; constexpr char kRank[] = "Rank"; constexpr char kShape[] = "Shape"; constexpr char kShapeN[] = "ShapeN"; @@ -68,6 +65,13 @@ constexpr char kAvgPoolGrad[] = "AvgPoolGrad"; constexpr char kFusedBatchNorm[] = "FusedBatchNorm"; constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad"; constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2"; +// Persistent ops. 
+constexpr char kConst[] = "Const"; +constexpr char kVariable[] = "Variable"; +constexpr char kVariableV2[] = "VariableV2"; +constexpr char kAutoReloadVariable[] = "AutoReloadVariable"; +constexpr char kVarHandleOp[] = "VarHandleOp"; +constexpr char kReadVariableOp[] = "ReadVariableOp"; static const Costs::Duration kMinComputeTime(1); @@ -259,10 +263,6 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)}, {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)}, - {kConst, wrap(&OpLevelCostEstimator::PredictVariable)}, - {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)}, - {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)}, - {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)}, {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)}, {kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)}, @@ -276,6 +276,11 @@ OpLevelCostEstimator::OpLevelCostEstimator() { wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)}, }; + persistent_ops_ = { + kConst, kVariable, kVariableV2, kAutoReloadVariable, + kVarHandleOp, kReadVariableOp, + }; + #define EIGEN_COST(X) Eigen::internal::functor_traits::Cost // Quantize = apply min and max bounds, multiply by scale factor and round. 
@@ -363,21 +368,25 @@ OpLevelCostEstimator::OpLevelCostEstimator() { Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const { const auto& op_info = op_context.op_info; auto it = device_cost_impl_.find(op_info.op()); - if (it == device_cost_impl_.end()) { - if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) { - return PredictCwiseOp(op_context); - } - - VLOG(1) << "Missing accurate estimator for op: " << op_info.op(); - - return PredictCostOfAnUnknownOp(op_context); + if (it != device_cost_impl_.end()) { + std::function estimator = it->second; + Costs costs = estimator(op_context); + VLOG(1) << "Operation " << op_info.op() << " takes " + << costs.execution_time.count() << " ns."; + return costs; } - std::function estimator = it->second; - Costs costs = estimator(op_context); - VLOG(1) << "Operation " << op_info.op() << " takes " - << costs.execution_time.count() << " ns."; - return costs; + if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) { + return PredictVariable(op_context); + } + + if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) { + return PredictCwiseOp(op_context); + } + + VLOG(1) << "Missing accurate estimator for op: " << op_info.op(); + + return PredictCostOfAnUnknownOp(op_context); } DeviceInfo OpLevelCostEstimator::GetDeviceInfo( @@ -1240,7 +1249,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const { result.num_ops_with_unknown_shapes = result.inaccurate; result.compute_time = kMinComputeTime; - result.execution_time = result.execution_time; + result.execution_time = result.compute_time; return result; } diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index f8ba8c6637d..ace8fb218c7 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -193,6 +193,7 @@ class OpLevelCostEstimator { // If true, 
assume compute and memory overlap; hence, the op cost is max of // compute_time and memory_time, insteaf of sum of those two. bool compute_memory_overlap_; + std::set persistent_ops_; private: friend class OpLevelCostEstimatorTest; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index aa0fc9d6c2a..04c6ada2bf6 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -499,6 +499,26 @@ class OpLevelCostEstimatorTest : public ::testing::Test { OpLevelCostEstimator estimator_; }; +TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + std::unordered_set persisent_ops = { + "Const", "Variable", "VariableV2", "AutoReloadVariable", + "VarHandleOp", "ReadVariableOp", + }; + // Minimum cost for all persistent ops. + for (const auto& op : persisent_ops) { + op_context.op_info.set_op(op); + auto cost = estimator_.PredictCosts(op_context); + EXPECT_EQ(Costs::Duration(0), cost.memory_time); + EXPECT_EQ(Costs::Duration(1), cost.compute_time); + EXPECT_EQ(Costs::Duration(1), cost.execution_time); + EXPECT_EQ(1, cost.num_ops_total); + EXPECT_FALSE(cost.inaccurate); + EXPECT_EQ(0, cost.num_ops_with_unknown_shapes); + } +} + TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) { OpContext op_context; SetCpuDevice(&op_context.op_info);