Return minimum costs for all persistent ops: Const and Vars.
Currently, PredictCosts() returns minimum costs for Const, Var, and VarV2, but not other var ops. Also, the logic was incorrect in that it sets minimum cost to compute time, but zero to execution time. This CL fixes these. PiperOrigin-RevId: 232510666
This commit is contained in:
parent
e1f2db44f4
commit
f722aee786
@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
|
|||||||
Costs summary;
|
Costs summary;
|
||||||
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary));
|
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary));
|
||||||
|
|
||||||
EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
|
EXPECT_EQ(Costs::NanoSeconds(9157), summary.execution_time);
|
||||||
// Note there are totally 17 nodes (RandomUniform creates 2 nodes), but
|
// Note there are totally 17 nodes (RandomUniform creates 2 nodes), but
|
||||||
// grappler will not process "label", therefore we have 15 here instead
|
// grappler will not process "label", therefore we have 15 here instead
|
||||||
EXPECT_EQ(15, summary.num_ops_total);
|
EXPECT_EQ(15, summary.num_ops_total);
|
||||||
|
@ -27,7 +27,6 @@ namespace tensorflow {
|
|||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
constexpr int kOpsPerMac = 2;
|
constexpr int kOpsPerMac = 2;
|
||||||
constexpr char kConst[] = "Const";
|
|
||||||
constexpr char kGuaranteeConst[] = "GuaranteeConst";
|
constexpr char kGuaranteeConst[] = "GuaranteeConst";
|
||||||
constexpr char kConv2d[] = "Conv2D";
|
constexpr char kConv2d[] = "Conv2D";
|
||||||
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
|
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
|
||||||
@ -50,8 +49,6 @@ constexpr char kSqueeze[] = "Squeeze";
|
|||||||
constexpr char kRecv[] = "_Recv";
|
constexpr char kRecv[] = "_Recv";
|
||||||
constexpr char kSend[] = "_Send";
|
constexpr char kSend[] = "_Send";
|
||||||
constexpr char kBatchMatMul[] = "BatchMatMul";
|
constexpr char kBatchMatMul[] = "BatchMatMul";
|
||||||
constexpr char kVariable[] = "Variable";
|
|
||||||
constexpr char kVariableV2[] = "VariableV2";
|
|
||||||
constexpr char kRank[] = "Rank";
|
constexpr char kRank[] = "Rank";
|
||||||
constexpr char kShape[] = "Shape";
|
constexpr char kShape[] = "Shape";
|
||||||
constexpr char kShapeN[] = "ShapeN";
|
constexpr char kShapeN[] = "ShapeN";
|
||||||
@ -68,6 +65,13 @@ constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
|
|||||||
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
|
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
|
||||||
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
|
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
|
||||||
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
|
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
|
||||||
|
// Persistent ops.
|
||||||
|
constexpr char kConst[] = "Const";
|
||||||
|
constexpr char kVariable[] = "Variable";
|
||||||
|
constexpr char kVariableV2[] = "VariableV2";
|
||||||
|
constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
|
||||||
|
constexpr char kVarHandleOp[] = "VarHandleOp";
|
||||||
|
constexpr char kReadVariableOp[] = "ReadVariableOp";
|
||||||
|
|
||||||
static const Costs::Duration kMinComputeTime(1);
|
static const Costs::Duration kMinComputeTime(1);
|
||||||
|
|
||||||
@ -259,10 +263,6 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
|||||||
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
|
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
|
||||||
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
|
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
|
||||||
|
|
||||||
{kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
|
|
||||||
{kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
|
|
||||||
{kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},
|
|
||||||
|
|
||||||
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
||||||
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
||||||
{kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
{kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
|
||||||
@ -276,6 +276,11 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
|||||||
wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
|
wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
persistent_ops_ = {
|
||||||
|
kConst, kVariable, kVariableV2, kAutoReloadVariable,
|
||||||
|
kVarHandleOp, kReadVariableOp,
|
||||||
|
};
|
||||||
|
|
||||||
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
|
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
|
||||||
|
|
||||||
// Quantize = apply min and max bounds, multiply by scale factor and round.
|
// Quantize = apply min and max bounds, multiply by scale factor and round.
|
||||||
@ -363,7 +368,18 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
|
|||||||
Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
|
Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
|
||||||
const auto& op_info = op_context.op_info;
|
const auto& op_info = op_context.op_info;
|
||||||
auto it = device_cost_impl_.find(op_info.op());
|
auto it = device_cost_impl_.find(op_info.op());
|
||||||
if (it == device_cost_impl_.end()) {
|
if (it != device_cost_impl_.end()) {
|
||||||
|
std::function<Costs(const OpContext&)> estimator = it->second;
|
||||||
|
Costs costs = estimator(op_context);
|
||||||
|
VLOG(1) << "Operation " << op_info.op() << " takes "
|
||||||
|
<< costs.execution_time.count() << " ns.";
|
||||||
|
return costs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
|
||||||
|
return PredictVariable(op_context);
|
||||||
|
}
|
||||||
|
|
||||||
if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
|
if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
|
||||||
return PredictCwiseOp(op_context);
|
return PredictCwiseOp(op_context);
|
||||||
}
|
}
|
||||||
@ -371,13 +387,6 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
|
|||||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||||
|
|
||||||
return PredictCostOfAnUnknownOp(op_context);
|
return PredictCostOfAnUnknownOp(op_context);
|
||||||
}
|
|
||||||
|
|
||||||
std::function<Costs(const OpContext&)> estimator = it->second;
|
|
||||||
Costs costs = estimator(op_context);
|
|
||||||
VLOG(1) << "Operation " << op_info.op() << " takes "
|
|
||||||
<< costs.execution_time.count() << " ns.";
|
|
||||||
return costs;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
|
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
|
||||||
@ -1240,7 +1249,7 @@ Costs OpLevelCostEstimator::PredictVariable(const OpContext& op_context) const {
|
|||||||
result.num_ops_with_unknown_shapes = result.inaccurate;
|
result.num_ops_with_unknown_shapes = result.inaccurate;
|
||||||
|
|
||||||
result.compute_time = kMinComputeTime;
|
result.compute_time = kMinComputeTime;
|
||||||
result.execution_time = result.execution_time;
|
result.execution_time = result.compute_time;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,6 +193,7 @@ class OpLevelCostEstimator {
|
|||||||
// If true, assume compute and memory overlap; hence, the op cost is max of
|
// If true, assume compute and memory overlap; hence, the op cost is max of
|
||||||
// compute_time and memory_time, insteaf of sum of those two.
|
// compute_time and memory_time, insteaf of sum of those two.
|
||||||
bool compute_memory_overlap_;
|
bool compute_memory_overlap_;
|
||||||
|
std::set<string> persistent_ops_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class OpLevelCostEstimatorTest;
|
friend class OpLevelCostEstimatorTest;
|
||||||
|
@ -499,6 +499,26 @@ class OpLevelCostEstimatorTest : public ::testing::Test {
|
|||||||
OpLevelCostEstimator estimator_;
|
OpLevelCostEstimator estimator_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
|
||||||
|
OpContext op_context;
|
||||||
|
SetCpuDevice(&op_context.op_info);
|
||||||
|
std::unordered_set<string> persisent_ops = {
|
||||||
|
"Const", "Variable", "VariableV2", "AutoReloadVariable",
|
||||||
|
"VarHandleOp", "ReadVariableOp",
|
||||||
|
};
|
||||||
|
// Minmum cost for all persistent ops.
|
||||||
|
for (const auto& op : persisent_ops) {
|
||||||
|
op_context.op_info.set_op(op);
|
||||||
|
auto cost = estimator_.PredictCosts(op_context);
|
||||||
|
EXPECT_EQ(Costs::Duration(0), cost.memory_time);
|
||||||
|
EXPECT_EQ(Costs::Duration(1), cost.compute_time);
|
||||||
|
EXPECT_EQ(Costs::Duration(1), cost.execution_time);
|
||||||
|
EXPECT_EQ(1, cost.num_ops_total);
|
||||||
|
EXPECT_FALSE(cost.inaccurate);
|
||||||
|
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
|
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
|
||||||
OpContext op_context;
|
OpContext op_context;
|
||||||
SetCpuDevice(&op_context.op_info);
|
SetCpuDevice(&op_context.op_info);
|
||||||
|
Loading…
Reference in New Issue
Block a user