- Added an input argument for aggressive shape inference mode in AnalyticalCostEstimator.

- Unified the duplicated cost estimation logic between AnalyticalCostEstimator and VirtualCluster: VirtualCluster now delegates to an internal AnalyticalCostEstimator instead of maintaining its own copy in VirtualCluster::Run.

PiperOrigin-RevId: 237718010
Peter Ma 2019-03-10 17:14:06 -07:00 committed by TensorFlower Gardener
parent bff52bb804
commit 7273a08672
9 changed files with 61 additions and 83 deletions
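
For reference, a minimal sketch of how a caller drives the updated API, assuming an already-provisioned Cluster and a populated GrapplerItem. The helper name EstimateGraphCost is illustrative and not part of this commit; the constructor and method signatures are taken from the diffs below.

// Illustrative sketch, not part of this commit. Mirrors what
// VirtualCluster::Run now does internally via its owned estimator.
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper: estimates the cost of `item` on `cluster`.
Status EstimateGraphCost(Cluster* cluster, const GrapplerItem& item,
                         RunMetadata* metadata) {
  // The new third argument toggles aggressive shape inference. Pass false to
  // preserve unknown shapes from the input graph (the choice VirtualCluster
  // makes below); pass true to minimize unknown shapes. The flag is only
  // meaningful together with static shape inference.
  AnalyticalCostEstimator estimator(cluster, /*use_static_shapes=*/true,
                                    /*use_aggressive_shape_inference=*/false);
  TF_RETURN_IF_ERROR(estimator.Initialize(item));
  Costs ignored_costs;
  return estimator.PredictCosts(item.graph, metadata, &ignored_costs);
}

}  // namespace grappler
}  // namespace tensorflow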

tensorflow/core/grappler/clusters/BUILD

@@ -81,6 +81,7 @@ cc_library(
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler/costs:analytical_cost_estimator",
         "//tensorflow/core/grappler/costs:op_level_cost_estimator",
         "//tensorflow/core/grappler/costs:virtual_scheduler",
     ],

tensorflow/core/grappler/clusters/virtual_cluster.cc

@@ -14,32 +14,33 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
 #include "tensorflow/core/framework/cost_graph.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/utils.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
 
 namespace tensorflow {
 namespace grappler {
 
 VirtualCluster::VirtualCluster(
     const std::unordered_map<string, DeviceProperties>& devices)
-    : Cluster(0),
-      node_estimator_(new OpLevelCostEstimator()),
-      node_manager_(new FirstReadyManager()) {
-  devices_ = devices;
-}
+    : VirtualCluster(devices, absl::make_unique<OpLevelCostEstimator>(),
+                     ReadyNodeManagerFactory("FirstReady")) {}
 
 VirtualCluster::VirtualCluster(
     const std::unordered_map<string, DeviceProperties>& devices,
     std::unique_ptr<OpLevelCostEstimator> node_estimator,
     std::unique_ptr<ReadyNodeManager> node_manager)
-    : Cluster(0),
-      node_estimator_(std::move(node_estimator)),
-      node_manager_(std::move(node_manager)) {
+    : Cluster(0) {
   devices_ = devices;
+
+  // Note that we do not use aggressive shape inference to preserve unknown
+  // shapes from the input graph.
+  estimator_ = absl::make_unique<AnalyticalCostEstimator>(
+      this, std::move(node_estimator), std::move(node_manager),
+      /*use_static_shapes=*/true, /*use_aggressive_shape_inference=*/false);
 }
 
 VirtualCluster::VirtualCluster(const DeviceSet* device_set)
@@ -66,19 +67,13 @@ Status VirtualCluster::Run(const GraphDef& graph,
                            const std::vector<std::pair<string, Tensor>>& feed,
                            const std::vector<string>& fetch,
                            RunMetadata* metadata) {
-  // Initialize a virtual scheduler to process the graph. Make sure to use
-  // static shape inference to prevent the scheduler from calling the Run
-  // method on the cluster and creating an infinite loop.
+  // Initializes an analytical cost estimator to estimate the graph cost. Makes
+  // sure to use static shape inference to prevent the virtual scheduler from
+  // calling the Run method on the cluster and creating an infinite loop.
   GrapplerItem item;
   item.graph = graph;
   item.feed = feed;
   item.fetch = fetch;
-  // Note that we do not use aggressive shape inference to preserve unknown
-  // shapes from the input graph.
-  VirtualScheduler scheduler(/*use_static_shapes=*/true,
-                             /*use_aggressive_shape_inference=*/false, this,
-                             node_manager_.get());
-  TF_RETURN_IF_ERROR(scheduler.Init(&item));
 
   if (metadata) {
     metadata->clear_step_stats();
@@ -86,45 +81,14 @@ Status VirtualCluster::Run(const GraphDef& graph,
     metadata->clear_partition_graphs();
   }
 
-  Costs node_costs;
-  int node_id = 0;
-  do {
-    OpContext op_context = scheduler.GetCurrNode();
-    node_costs = node_estimator_->PredictCosts(op_context);
-    if (metadata) {
-      CostGraphDef::Node* cost_node =
-          metadata->mutable_cost_graph()->add_node();
-      const string& op_name = op_context.name;
-      cost_node->set_id(node_id++);
-      cost_node->set_name(op_name);
-      cost_node->set_device(op_context.device_name);
-      cost_node->set_compute_cost(
-          node_costs.execution_time.asMicroSeconds().count());
-      cost_node->set_compute_time(
-          node_costs.compute_time.asMicroSeconds().count());
-      cost_node->set_memory_time(
-          node_costs.memory_time.asMicroSeconds().count());
-      for (const auto& output : op_context.op_info.outputs()) {
-        auto output_info = cost_node->add_output_info();
-        output_info->set_dtype(output.dtype());
-        *output_info->mutable_shape() = output.shape();
-        int64 size = DataTypeSize(output.dtype());
-        for (const auto& dim : output.shape().dim()) {
-          size *= std::max<int64>(1, dim.size());
-        }
-        output_info->set_size(size);
-      }
-    }
-  } while (scheduler.MarkCurrNodeExecuted(node_costs));
-
-  if (metadata) {
-    scheduler.Summary(metadata);
-  }
+  TF_RETURN_IF_ERROR(estimator_->Initialize(item));
+  Costs ignored_costs;
+  TF_RETURN_IF_ERROR(
+      estimator_->PredictCosts(item.graph, metadata, &ignored_costs));
 
   const std::unordered_map<string, DeviceProperties>& device = GetDevices();
   std::unordered_map<string, int64> peak_mem_usage =
-      scheduler.GetPeakMemoryUsage();
+      estimator_->GetScheduler()->GetPeakMemoryUsage();
   for (const auto& mem_usage : peak_mem_usage) {
     const string& device_name = mem_usage.first;
     auto it = device.find(device_name);

tensorflow/core/grappler/clusters/virtual_cluster.h

@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"
@@ -50,9 +51,8 @@ class VirtualCluster : public Cluster {
   const DeviceSet* GetDeviceSet() const override { return device_set_; }
 
  private:
-  std::unique_ptr<OpLevelCostEstimator> node_estimator_;
-  std::unique_ptr<ReadyNodeManager> node_manager_;
-  const DeviceSet* device_set_ = nullptr;  // Not owned
+  std::unique_ptr<AnalyticalCostEstimator> estimator_;
+  const DeviceSet* device_set_ = nullptr;
 };
 
 }  // end namespace grappler

tensorflow/core/grappler/costs/analytical_cost_estimator.cc

@@ -56,7 +56,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
     (*name_to_id)[node->name()] = node->id();
   }
   // For nodes we have seen before (e.g. Merge nodes are executed twice by
-  // VirtualScheduler), the following fields will be overwritten/updated
+  // VirtualScheduler), the following fields will be overwritten/updated.
   node->set_device(op_context.device_name);
   node->set_compute_cost(node_costs.execution_time.asMicroSeconds().count());
   node->set_compute_time(node_costs.compute_time.asMicroSeconds().count());
@@ -67,7 +67,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
     int input_port;
     string input_name = ParseNodeName(input, &input_port);
 
-    // All inputs should have been seen already unless this is a Merge node
+    // All inputs should have been seen already unless this is a Merge node.
     if (name_to_id->find(input_name) == name_to_id->end()) {
       if (!IsMerge(*node_manager->GetCurrNode()))
         LOG(ERROR) << "input: " << input
@@ -76,7 +76,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
       // For Merge node, some of inputs may not be seen before
       // For example, for a typical while loop in tensorflow, Merge node
       // will be executed twice by VirtualScheduler (one for Enter, the
-      // other for NextIteration), so eventually both inputs will be added
+      // other for NextIteration), so eventually both inputs will be added.
       continue;
     }
 
@@ -93,30 +93,38 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
     auto output_info = node->add_output_info();
     output_info->set_alias_input_port(-1);
     output_info->set_dtype(output.dtype());
-    auto shape = output_info->mutable_shape();
-    *shape = output.shape();
+    *output_info->mutable_shape() = output.shape();
+
+    int64 size = DataTypeSize(output.dtype());
+    for (const auto& dim : output.shape().dim()) {
+      size *= std::max<int64>(1, dim.size());
+    }
+    output_info->set_size(size);
   }
 }
 
 }  // namespace
 
-AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
-                                                 bool use_static_shapes)
+AnalyticalCostEstimator::AnalyticalCostEstimator(
+    Cluster* cluster, bool use_static_shapes,
+    bool use_aggressive_shape_inference)
     : AnalyticalCostEstimator(
           cluster, absl::make_unique<OpLevelCostEstimator>(),
-          ReadyNodeManagerFactory("FirstReady"), use_static_shapes) {}
+          ReadyNodeManagerFactory("FirstReady"), use_static_shapes,
+          use_aggressive_shape_inference) {}
 
 AnalyticalCostEstimator::AnalyticalCostEstimator(
     Cluster* cluster, std::unique_ptr<OpLevelCostEstimator> node_estimator,
-    std::unique_ptr<ReadyNodeManager> node_manager, bool use_static_shapes)
+    std::unique_ptr<ReadyNodeManager> node_manager, bool use_static_shapes,
+    bool use_aggressive_shape_inference)
     : cluster_(cluster),
       node_estimator_(std::move(node_estimator)),
       node_manager_(std::move(node_manager)),
-      use_static_shapes_(use_static_shapes) {
-  // Use aggressive static shape inference to minimize unknown shapes.
+      use_static_shapes_(use_static_shapes),
+      use_aggressive_shape_inference_(use_aggressive_shape_inference) {
   scheduler_ = absl::make_unique<VirtualScheduler>(
-      use_static_shapes_,
-      /*use_aggressive_shape_inference=*/true, cluster_, node_manager_.get());
+      use_static_shapes_, use_aggressive_shape_inference_, cluster_,
+      node_manager_.get());
 }
 
 Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
@@ -142,7 +150,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
     cost_graph = run_metadata->mutable_cost_graph();
     // TODO(pcma): Clear nodes in cost_graph after we make sure we always pass
     // in an empty cost_graph (a non-empty but incomplete cost_graph will cause
-    // problems, e.g., no node_id in cost_graph)
+    // problems, e.g., no node_id in cost_graph).
     for (auto& node : *cost_graph->mutable_node()) {
       name_to_cost_node[node.name()] = &node;
     }
@@ -165,7 +173,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
              << node_costs.num_ops_with_unknown_shapes << " unknown shapes";
     }
 
-    // TODO(pcma): Add unit tests for generating CostGraphDef
+    // TODO(pcma): Add unit tests for generating CostGraphDef.
     if (cost_graph) {
       AddCostNode(node_manager_.get(), op_context, node_id++, node_costs,
                   &name_to_cost_node, &name_to_id, cost_graph);

tensorflow/core/grappler/costs/analytical_cost_estimator.h

@@ -35,15 +35,19 @@ struct GrapplerItem;
 
 // Estimate the cost of running a Grappler item based on the theoretical
 // performance of the hardware that will run the model. Note that this
-// internally uses aggressive shape inference with static shape inference.
+// internally uses static shape inference. An option for aggressive shape
+// inference is provided to minimize unknown shapes, and this is only applicable
+// with static shape inference.
 class AnalyticalCostEstimator : public CostEstimator {
  public:
   // Does not take ownership of cluster.
-  AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
+  AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes,
+                          bool use_aggressive_shape_inference);
   AnalyticalCostEstimator(Cluster* cluster,
                           std::unique_ptr<OpLevelCostEstimator> node_estimator,
                           std::unique_ptr<ReadyNodeManager> node_manager,
-                          bool use_static_shapes);
+                          bool use_static_shapes,
+                          bool use_aggressive_shape_inference);
   ~AnalyticalCostEstimator() override {}
 
   // Initializes the estimator for the specified grappler item.
@@ -63,8 +67,10 @@ class AnalyticalCostEstimator : public CostEstimator {
   GrapplerItem item_;
   std::unique_ptr<OpLevelCostEstimator> node_estimator_;
   std::unique_ptr<ReadyNodeManager> node_manager_;
-  bool use_static_shapes_;
   std::unique_ptr<VirtualScheduler> scheduler_;
+  bool use_static_shapes_;
+  bool use_aggressive_shape_inference_;
 };
 
 }  // end namespace grappler
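
The dependency-injecting constructor is kept with the new flag appended. A minimal sketch of using it directly, mirroring the delegating constructor in the .cc diff above; the factory function MakeEstimator is illustrative and not part of this commit:

// Illustrative sketch, not part of this commit.
#include "absl/memory/memory.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

namespace tensorflow {
namespace grappler {

// Hypothetical factory: injects a custom node estimator and ready-node
// manager, matching the second constructor's signature from the diff.
std::unique_ptr<AnalyticalCostEstimator> MakeEstimator(Cluster* cluster) {
  return absl::make_unique<AnalyticalCostEstimator>(
      cluster, absl::make_unique<OpLevelCostEstimator>(),
      ReadyNodeManagerFactory("FirstReady"), /*use_static_shapes=*/true,
      /*use_aggressive_shape_inference=*/false);
}

}  // namespace grappler
}  // namespace tensorflow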

tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc

@@ -94,7 +94,8 @@ class AnalyticalCostEstimatorTest : public ::testing::Test {
 TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
   GrapplerItem item = CreateMiniGraph();
 
-  AnalyticalCostEstimator estimator(cluster_.get(), true);
+  AnalyticalCostEstimator estimator(cluster_.get(), /*use_static_shapes=*/true,
+                                    /*use_aggressive_shape_inference=*/true);
   TF_ASSERT_OK(estimator.Initialize(item));
 
   RunMetadata run_metadata;

tensorflow/python/grappler/cluster.i

@@ -153,7 +153,7 @@ static GCluster TF_NewVirtualCluster(
   for (const auto& named_device : named_devices) {
     devices[named_device.name()]= named_device.properties();
   }
-  tensorflow::grappler::Cluster*cluster_ =
+  tensorflow::grappler::Cluster* cluster_ =
       new tensorflow::grappler::VirtualCluster(devices);
   PyGILState_STATE gstate = PyGILState_Ensure();
   tensorflow::Status status = cluster_->Provision();

tensorflow/python/grappler/cluster_test.py

@@ -99,9 +99,7 @@ class ClusterTest(test.TestCase):
         type='GPU',
         frequency=1000,
         num_cores=60,
-        environment={
-            'architecture': '7'
-        })
+        environment={'architecture': '7'})
     named_device = device_properties_pb2.NamedDevice(
         properties=device_properties, name='/device:GPU:0')
     grappler_cluster = cluster.Cluster(

tensorflow/python/grappler/cost_analyzer.cc

@@ -27,7 +27,8 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                            const string& suffix)
     : item_(&item),
       measure_estimator_(cluster, 10, 0),
-      analytical_estimator_(cluster, false),
+      analytical_estimator_(cluster, /*use_static_shapes=*/false,
+                            /*use_aggressive_shape_inference=*/true),
       suffix_(suffix) {}
 
 Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
@@ -125,7 +126,6 @@ void CostAnalyzer::PreprocessCosts() {
   }
 }
 
 void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
-
   for (const auto& op : ops) {
     ops_.push_back(op.second);