Return minimum costs for all persistent ops: Const and Vars.

Currently, PredictCosts() returns minimum costs only for Const, Variable, and VariableV2, but not for the other variable ops (AutoReloadVariable, VarHandleOp, ReadVariableOp). In addition, the logic was incorrect: it set the minimum cost on compute_time but left execution_time at zero.

This CL fixes both issues.

PiperOrigin-RevId: 232510666
Author:    Doe Hyun Yoon
Date:      2019-02-05 10:21:32 -08:00
Committer: TensorFlower Gardener
Commit:    f722aee786 (parent e1f2db44f4)

4 changed files with 52 additions and 22 deletions
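To make the description above concrete, here is what a caller sees before and after this CL. The snippet is only a sketch (estimator construction and device setup are simplified, and the names are assumed); the expected values come from the new TestPersistentOpCosts test further down in this diff.

  OpLevelCostEstimator estimator;
  OpContext op_context;
  op_context.op_info.set_op("Variable");  // any of the six persistent ops
  Costs cost = estimator.PredictCosts(op_context);
  // Before: compute_time was set to the 1 ns minimum, but execution_time
  //         stayed at zero (the bug described above), and ops such as
  //         VarHandleOp were not treated as persistent at all.
  // After:  compute_time == execution_time == Costs::Duration(1) and
  //         memory_time == Costs::Duration(0) for all six persistent ops.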

tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc

@@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
   Costs summary;
   TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary));
 
-  EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
+  EXPECT_EQ(Costs::NanoSeconds(9157), summary.execution_time);
   // Note there are totally 17 nodes (RandomUniform creates 2 nodes), but
   // grappler will not process "label", therefore we have 15 here instead
   EXPECT_EQ(15, summary.num_ops_total);
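The 6 ns bump in the expected total (9151 -> 9157) is consistent with the fix: persistent nodes now contribute their 1 ns minimum to execution_time instead of 0. Assuming the SimpleTest graph contains six such nodes (an inference from the delta, not something shown in this diff), the arithmetic is simply 9151 ns + 6 * 1 ns = 9157 ns.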

tensorflow/core/grappler/costs/op_level_cost_estimator.cc

@@ -27,7 +27,6 @@ namespace tensorflow {
 namespace grappler {
 
 constexpr int kOpsPerMac = 2;
-constexpr char kConst[] = "Const";
 constexpr char kGuaranteeConst[] = "GuaranteeConst";
 constexpr char kConv2d[] = "Conv2D";
 constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
@@ -50,8 +49,6 @@ constexpr char kSqueeze[] = "Squeeze";
 constexpr char kRecv[] = "_Recv";
 constexpr char kSend[] = "_Send";
 constexpr char kBatchMatMul[] = "BatchMatMul";
-constexpr char kVariable[] = "Variable";
-constexpr char kVariableV2[] = "VariableV2";
 constexpr char kRank[] = "Rank";
 constexpr char kShape[] = "Shape";
 constexpr char kShapeN[] = "ShapeN";
@@ -68,6 +65,13 @@ constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
 constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
 constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
 constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
+// Persistent ops.
+constexpr char kConst[] = "Const";
+constexpr char kVariable[] = "Variable";
+constexpr char kVariableV2[] = "VariableV2";
+constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
+constexpr char kVarHandleOp[] = "VarHandleOp";
+constexpr char kReadVariableOp[] = "ReadVariableOp";
 
 static const Costs::Duration kMinComputeTime(1);
@@ -259,10 +263,6 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
       {kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
       {kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
-      {kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
-      {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
-      {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},
       {kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
       {kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
       {kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
@@ -276,6 +276,11 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
        wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
   };
 
+  persistent_ops_ = {
+      kConst, kVariable, kVariableV2, kAutoReloadVariable,
+      kVarHandleOp, kReadVariableOp,
+  };
+
 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
 
 // Quantize = apply min and max bounds, multiply by scale factor and round.
@@ -363,7 +368,18 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
 Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
   const auto& op_info = op_context.op_info;
   auto it = device_cost_impl_.find(op_info.op());
-  if (it == device_cost_impl_.end()) {
+  if (it != device_cost_impl_.end()) {
+    std::function<Costs(const OpContext&)> estimator = it->second;
+    Costs costs = estimator(op_context);
+    VLOG(1) << "Operation " << op_info.op() << " takes "
+            << costs.execution_time.count() << " ns.";
+    return costs;
+  }
+
+  if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
+    return PredictVariable(op_context);
+  }
+
   if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
     return PredictCwiseOp(op_context);
   }
@@ -371,13 +387,6 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
   VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
   return PredictCostOfAnUnknownOp(op_context);
-  }
-
-  std::function<Costs(const OpContext&)> estimator = it->second;
-  Costs costs = estimator(op_context);
-  VLOG(1) << "Operation " << op_info.op() << " takes "
-          << costs.execution_time.count() << " ns.";
-  return costs;
 }
 
 DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
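Reading the two hunks above together, PredictCosts() after this change looks roughly like the following. This is a reconstruction for readability; the comments and any blank lines outside the shown hunks are additions, not code from this CL.

Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
  const auto& op_info = op_context.op_info;
  // 1. Ops with a registered estimator use it.
  auto it = device_cost_impl_.find(op_info.op());
  if (it != device_cost_impl_.end()) {
    std::function<Costs(const OpContext&)> estimator = it->second;
    Costs costs = estimator(op_context);
    VLOG(1) << "Operation " << op_info.op() << " takes "
            << costs.execution_time.count() << " ns.";
    return costs;
  }

  // 2. Persistent ops (Const and the variable ops) get the minimum cost.
  if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
    return PredictVariable(op_context);
  }

  // 3. Elementwise ops fall back to the generic cwise estimate.
  if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
    return PredictCwiseOp(op_context);
  }

  // 4. Everything else is treated as an unknown op.
  VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
  return PredictCostOfAnUnknownOp(op_context);
}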
@@ -1240,7 +1249,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
   result.num_ops_with_unknown_shapes = result.inaccurate;
 
   result.compute_time = kMinComputeTime;
-  result.execution_time = result.execution_time;
+  result.execution_time = result.compute_time;
 
   return result;
 }
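For context, PredictVariable() after this fix is roughly as follows. Only the tail of the function appears in the hunk above; the initialization lines are reconstructed from the expectations in the new test below and should be read as an assumption.

Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
  Costs result = Costs::ZeroCosts();
  result.num_ops_total = 1;
  result.inaccurate = false;
  result.num_ops_with_unknown_shapes = result.inaccurate;

  result.compute_time = kMinComputeTime;        // 1 ns floor
  result.execution_time = result.compute_time;  // was a self-assignment, leaving 0

  return result;
}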

tensorflow/core/grappler/costs/op_level_cost_estimator.h

@@ -193,6 +193,7 @@ class OpLevelCostEstimator {
   // If true, assume compute and memory overlap; hence, the op cost is max of
   // compute_time and memory_time, insteaf of sum of those two.
   bool compute_memory_overlap_;
+  std::set<string> persistent_ops_;
 
  private:
  friend class OpLevelCostEstimatorTest;

tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

@@ -499,6 +499,26 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
   OpLevelCostEstimator estimator_;
 };
 
+TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
+  OpContext op_context;
+  SetCpuDevice(&op_context.op_info);
+  std::unordered_set<string> persistent_ops = {
+      "Const", "Variable", "VariableV2", "AutoReloadVariable",
+      "VarHandleOp", "ReadVariableOp",
+  };
+  // Minimum cost for all persistent ops.
+  for (const auto& op : persistent_ops) {
+    op_context.op_info.set_op(op);
+    auto cost = estimator_.PredictCosts(op_context);
+    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
+    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
+    EXPECT_EQ(1, cost.num_ops_total);
+    EXPECT_FALSE(cost.inaccurate);
+    EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+  }
+}
+
 TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
   OpContext op_context;
   SetCpuDevice(&op_context.op_info);
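To run the new test locally, something like the following should work, assuming the usual Bazel target layout for these files (the target name is inferred, not taken from this CL):

  bazel test //tensorflow/core/grappler/costs:op_level_cost_estimator_test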