Return minimum costs for all persistent ops: Const and Vars.
Previously, PredictCosts() returned minimum costs only for Const, Variable, and VariableV2, but not for other variable ops. The logic was also incorrect: it set compute time to the minimum cost but left execution time at zero. This CL fixes both issues. PiperOrigin-RevId: 232510666
This commit is contained in:
parent
e1f2db44f4
commit
f722aee786
@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
|
||||
Costs summary;
|
||||
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary));
|
||||
|
||||
EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
|
||||
EXPECT_EQ(Costs::NanoSeconds(9157), summary.execution_time);
|
||||
// Note there are totally 17 nodes (RandomUniform creates 2 nodes), but
|
||||
// grappler will not process "label", therefore we have 15 here instead
|
||||
EXPECT_EQ(15, summary.num_ops_total);
|
||||
|
@ -27,7 +27,6 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
constexpr int kOpsPerMac = 2;
|
||||
constexpr char kConst[] = "Const";
|
||||
constexpr char kGuaranteeConst[] = "GuaranteeConst";
|
||||
constexpr char kConv2d[] = "Conv2D";
|
||||
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
|
||||
@ -50,8 +49,6 @@ constexpr char kSqueeze[] = "Squeeze";
|
||||
constexpr char kRecv[] = "_Recv";
|
||||
constexpr char kSend[] = "_Send";
|
||||
constexpr char kBatchMatMul[] = "BatchMatMul";
|
||||
constexpr char kVariable[] = "Variable";
|
||||
constexpr char kVariableV2[] = "VariableV2";
|
||||
constexpr char kRank[] = "Rank";
|
||||
constexpr char kShape[] = "Shape";
|
||||
constexpr char kShapeN[] = "ShapeN";
|
||||
@ -68,6 +65,13 @@ constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
|
||||
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
|
||||
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
|
||||
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
|
||||
// Persistent ops.
|
||||
constexpr char kConst[] = "Const";
|
||||
constexpr char kVariable[] = "Variable";
|
||||
constexpr char kVariableV2[] = "VariableV2";
|
||||
constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
|
||||
constexpr char kVarHandleOp[] = "VarHandleOp";
|
||||
constexpr char kReadVariableOp[] = "ReadVariableOp";
|
||||
|
||||
static const Costs::Duration kMinComputeTime(1);
|
||||
|
||||
@ -259,10 +263,6 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
|
||||
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
|
||||
|
||||
{kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
|
||||
{kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
|
||||
{kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},
|
||||
|
||||
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
||||
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
||||
{kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
||||
@ -276,6 +276,11 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
|
||||
};
|
||||
|
||||
persistent_ops_ = {
|
||||
kConst, kVariable, kVariableV2, kAutoReloadVariable,
|
||||
kVarHandleOp, kReadVariableOp,
|
||||
};
|
||||
|
||||
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
|
||||
|
||||
// Quantize = apply min and max bounds, multiply by scale factor and round.
|
||||
@ -363,7 +368,18 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
||||
Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
|
||||
const auto& op_info = op_context.op_info;
|
||||
auto it = device_cost_impl_.find(op_info.op());
|
||||
if (it == device_cost_impl_.end()) {
|
||||
if (it != device_cost_impl_.end()) {
|
||||
std::function<Costs(const OpContext&)> estimator = it->second;
|
||||
Costs costs = estimator(op_context);
|
||||
VLOG(1) << "Operation " << op_info.op() << " takes "
|
||||
<< costs.execution_time.count() << " ns.";
|
||||
return costs;
|
||||
}
|
||||
|
||||
if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
|
||||
return PredictVariable(op_context);
|
||||
}
|
||||
|
||||
if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
|
||||
return PredictCwiseOp(op_context);
|
||||
}
|
||||
@ -371,13 +387,6 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
|
||||
std::function<Costs(const OpContext&)> estimator = it->second;
|
||||
Costs costs = estimator(op_context);
|
||||
VLOG(1) << "Operation " << op_info.op() << " takes "
|
||||
<< costs.execution_time.count() << " ns.";
|
||||
return costs;
|
||||
}
|
||||
|
||||
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
|
||||
@ -1240,7 +1249,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
|
||||
result.num_ops_with_unknown_shapes = result.inaccurate;
|
||||
|
||||
result.compute_time = kMinComputeTime;
|
||||
result.execution_time = result.execution_time;
|
||||
result.execution_time = result.compute_time;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -193,6 +193,7 @@ class OpLevelCostEstimator {
|
||||
// If true, assume compute and memory overlap; hence, the op cost is max of
|
||||
// compute_time and memory_time, instead of sum of those two.
|
||||
bool compute_memory_overlap_;
|
||||
std::set<string> persistent_ops_;
|
||||
|
||||
private:
|
||||
friend class OpLevelCostEstimatorTest;
|
||||
|
@ -499,6 +499,26 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
|
||||
OpLevelCostEstimator estimator_;
|
||||
};
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
std::unordered_set<string> persisent_ops = {
|
||||
"Const", "Variable", "VariableV2", "AutoReloadVariable",
|
||||
"VarHandleOp", "ReadVariableOp",
|
||||
};
|
||||
// Minimum cost for all persistent ops.
|
||||
for (const auto& op : persisent_ops) {
|
||||
op_context.op_info.set_op(op);
|
||||
auto cost = estimator_.PredictCosts(op_context);
|
||||
EXPECT_EQ(Costs::Duration(0), cost.memory_time);
|
||||
EXPECT_EQ(Costs::Duration(1), cost.compute_time);
|
||||
EXPECT_EQ(Costs::Duration(1), cost.execution_time);
|
||||
EXPECT_EQ(1, cost.num_ops_total);
|
||||
EXPECT_FALSE(cost.inaccurate);
|
||||
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
|
||||
OpContext op_context;
|
||||
SetCpuDevice(&op_context.op_info);
|
||||
|
Loading…
Reference in New Issue
Block a user