Remove AnalyticalCostEstimator from tf-sim

This CL removes AnalyticalCostEstimator from tf-sim and uses VirtualScheduler
instead. It also moves CombineCostsAndUpdateExecutionTime from OpLevelCostEstimator to grappler/costs/utils.

PiperOrigin-RevId: 286460843
Change-Id: Ib5e27cf4e5885da2d1c650a892ae578ad3a84154
This commit is contained in:
Andiry Xu 2019-12-19 14:10:23 -08:00 committed by TensorFlower Gardener
parent 079a0bcc55
commit bd05d8197c
5 changed files with 23 additions and 19 deletions

View File

@ -136,6 +136,7 @@ tf_cuda_library(
hdrs = ["utils.h"],
visibility = ["//visibility:public"],
deps = [
":cost_estimator",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
@ -289,6 +290,7 @@ cc_library(
deps = [
":cost_estimator",
":op_context",
":utils",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/utils.h"
namespace tensorflow {
namespace grappler {
@ -659,7 +660,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
Costs::NanoSeconds(intermediate_read_time);
costs.intermediate_memory_write_time =
Costs::NanoSeconds(intermediate_write_time);
CombineCostsAndUpdateExecutionTime(&costs);
CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
return costs;
}
@ -1715,7 +1716,7 @@ Costs OpLevelCostEstimator::PredictFusedOp(
fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
}
CombineCostsAndUpdateExecutionTime(&fused_cost);
CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &fused_cost);
return fused_cost;
}
@ -2050,17 +2051,5 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
costs.max_memory = total_output_size;
return costs;
}
// Derives execution_time from the already-computed compute, memory, and
// intermediate-memory times. When compute/memory overlap is modeled, the
// total is the slowest of the three components; otherwise they are assumed
// to run serially and their times are summed.
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
    Costs* costs) const {
  if (!compute_memory_overlap_) {
    costs->execution_time = costs->compute_time + costs->memory_time +
                            costs->intermediate_memory_time;
    return;
  }
  costs->execution_time =
      std::max(costs->compute_time,
               std::max(costs->memory_time, costs->intermediate_memory_time));
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -194,11 +194,6 @@ class OpLevelCostEstimator {
static OpInfo::TensorProperties DescribeTensor(
DataType type, const std::vector<int64>& dims);
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
void CombineCostsAndUpdateExecutionTime(Costs* costs) const;
protected:
std::map<string, int> elementwise_ops_;
typedef std::function<Costs(const OpContext& op_context)> CostImpl;

View File

@ -504,5 +504,16 @@ string GetStatsStringFromRunMetadata(const RunMetadata& run_metadata,
return output.str();
}
// Computes the total execution time from the compute, memory, and
// intermediate-memory components that must already be populated in `costs`.
// If `compute_memory_overlap` is true, IO overlaps with computation and the
// total is the maximum component time; otherwise the components are serial
// and the total is their sum.
void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap,
                                        Costs* costs) {
  const auto& compute = costs->compute_time;
  const auto& memory = costs->memory_time;
  const auto& intermediate = costs->intermediate_memory_time;
  costs->execution_time =
      compute_memory_overlap
          ? std::max(intermediate, std::max(compute, memory))
          : compute + memory + intermediate;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
@ -119,6 +120,12 @@ string GetDeviceClass(const string& device_name);
string GetStatsStringFromRunMetadata(const RunMetadata& run_metadata,
bool verbosity);
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap,
Costs* costs);
} // end namespace grappler
} // end namespace tensorflow