Remove AnalyticalCostEstimator from tf-sim
This CL removes AnalyticalCostEstimator from tf-sim and uses VirtualScheduler instead. It also moves CombineCostsAndUpdateExecutionTime from OpLevelCostEstimator to grappler/costs/utils. PiperOrigin-RevId: 286460843 Change-Id: Ib5e27cf4e5885da2d1c650a892ae578ad3a84154
This commit is contained in:
parent
079a0bcc55
commit
bd05d8197c
tensorflow/core/grappler/costs
@ -136,6 +136,7 @@ tf_cuda_library(
|
||||
hdrs = ["utils.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":cost_estimator",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
@ -289,6 +290,7 @@ cc_library(
|
||||
deps = [
|
||||
":cost_estimator",
|
||||
":op_context",
|
||||
":utils",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/clusters/utils.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -659,7 +660,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
|
||||
Costs::NanoSeconds(intermediate_read_time);
|
||||
costs.intermediate_memory_write_time =
|
||||
Costs::NanoSeconds(intermediate_write_time);
|
||||
CombineCostsAndUpdateExecutionTime(&costs);
|
||||
CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
|
||||
return costs;
|
||||
}
|
||||
|
||||
@ -1715,7 +1716,7 @@ Costs OpLevelCostEstimator::PredictFusedOp(
|
||||
fused_cost.intermediate_memory_time += op_cost.intermediate_memory_time;
|
||||
}
|
||||
|
||||
CombineCostsAndUpdateExecutionTime(&fused_cost);
|
||||
CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &fused_cost);
|
||||
return fused_cost;
|
||||
}
|
||||
|
||||
@ -2050,17 +2051,5 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
|
||||
costs.max_memory = total_output_size;
|
||||
return costs;
|
||||
}
|
||||
|
||||
// Folds the per-component costs into `costs->execution_time`.
// When `compute_memory_overlap_` is set, IO is assumed to overlap with
// computation, so the slowest component dominates; otherwise the
// components execute serially and their times add up. Assumes
// compute/memory/intermediate-memory times were already filled in.
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
    Costs* costs) const {
  costs->execution_time =
      compute_memory_overlap_
          ? std::max(costs->intermediate_memory_time,
                     std::max(costs->compute_time, costs->memory_time))
          : costs->compute_time + costs->memory_time +
                costs->intermediate_memory_time;
}
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -194,11 +194,6 @@ class OpLevelCostEstimator {
|
||||
static OpInfo::TensorProperties DescribeTensor(
|
||||
DataType type, const std::vector<int64>& dims);
|
||||
|
||||
// This method calculates the execution time depending on whether IO can
|
||||
// overlap with computation. It assumes the memory and the compute times have
|
||||
// already been calculated.
|
||||
void CombineCostsAndUpdateExecutionTime(Costs* costs) const;
|
||||
|
||||
protected:
|
||||
std::map<string, int> elementwise_ops_;
|
||||
typedef std::function<Costs(const OpContext& op_context)> CostImpl;
|
||||
|
@ -504,5 +504,16 @@ string GetStatsStringFromRunMetadata(const RunMetadata& run_metadata,
|
||||
return output.str();
|
||||
}
|
||||
|
||||
void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap,
|
||||
Costs* costs) {
|
||||
if (compute_memory_overlap) {
|
||||
costs->execution_time =
|
||||
std::max(costs->intermediate_memory_time,
|
||||
std::max(costs->compute_time, costs->memory_time));
|
||||
} else {
|
||||
costs->execution_time = costs->compute_time + costs->memory_time +
|
||||
costs->intermediate_memory_time;
|
||||
}
|
||||
}
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/graph/types.h"
|
||||
#include "tensorflow/core/grappler/costs/cost_estimator.h"
|
||||
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
@ -119,6 +120,12 @@ string GetDeviceClass(const string& device_name);
|
||||
string GetStatsStringFromRunMetadata(const RunMetadata& run_metadata,
|
||||
bool verbosity);
|
||||
|
||||
// This method calculates the execution time depending on whether IO can
|
||||
// overlap with computation. It assumes the memory and the compute times have
|
||||
// already been calculated.
|
||||
void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap,
|
||||
Costs* costs);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user