- Added input argument for aggressive shape inference mode in AnalyticalCostEstimator.
- Unified the cost estimation logic shared by AnalyticalCostEstimator and VirtualCluster: VirtualCluster now delegates to AnalyticalCostEstimator.

PiperOrigin-RevId: 237718010
parent bff52bb804
commit 7273a08672
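From the caller's side, the new argument is used roughly as follows. This is an illustrative sketch, not code from this commit: EstimateGraphCost is a hypothetical helper name, while the constructor arguments and the Initialize/PredictCosts calls mirror the signatures changed in the diff below.

#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper (not part of this commit): estimate per-node costs for
// `item` on `cluster` and fill the cost graph in `metadata`.
Status EstimateGraphCost(Cluster* cluster, const GrapplerItem& item,
                         RunMetadata* metadata) {
  // Static shapes on; aggressive shape inference off, which preserves unknown
  // shapes from the input graph (the same defaults VirtualCluster now uses).
  AnalyticalCostEstimator estimator(cluster, /*use_static_shapes=*/true,
                                    /*use_aggressive_shape_inference=*/false);
  TF_RETURN_IF_ERROR(estimator.Initialize(item));
  Costs ignored_costs;
  return estimator.PredictCosts(item.graph, metadata, &ignored_costs);
}

}  // namespace grappler
}  // namespace tensorflow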
@@ -81,6 +81,7 @@ cc_library(
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler/costs:analytical_cost_estimator",
         "//tensorflow/core/grappler/costs:op_level_cost_estimator",
         "//tensorflow/core/grappler/costs:virtual_scheduler",
     ],
@@ -14,32 +14,33 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
 
 #include "tensorflow/core/framework/cost_graph.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/utils.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
 
 namespace tensorflow {
 namespace grappler {
 
 VirtualCluster::VirtualCluster(
     const std::unordered_map<string, DeviceProperties>& devices)
-    : Cluster(0),
-      node_estimator_(new OpLevelCostEstimator()),
-      node_manager_(new FirstReadyManager()) {
-  devices_ = devices;
-}
+    : VirtualCluster(devices, absl::make_unique<OpLevelCostEstimator>(),
+                     ReadyNodeManagerFactory("FirstReady")) {}
 
 VirtualCluster::VirtualCluster(
     const std::unordered_map<string, DeviceProperties>& devices,
     std::unique_ptr<OpLevelCostEstimator> node_estimator,
     std::unique_ptr<ReadyNodeManager> node_manager)
-    : Cluster(0),
-      node_estimator_(std::move(node_estimator)),
-      node_manager_(std::move(node_manager)) {
+    : Cluster(0) {
   devices_ = devices;
+
+  // Note that we do not use aggressive shape inference to preserve unknown
+  // shapes from the input graph.
+  estimator_ = absl::make_unique<AnalyticalCostEstimator>(
+      this, std::move(node_estimator), std::move(node_manager),
+      /*use_static_shapes=*/true, /*use_aggressive_shape_inference=*/false);
 }
 
 VirtualCluster::VirtualCluster(const DeviceSet* device_set)
@@ -66,19 +67,13 @@ Status VirtualCluster::Run(const GraphDef& graph,
                            const std::vector<std::pair<string, Tensor>>& feed,
                            const std::vector<string>& fetch,
                            RunMetadata* metadata) {
-  // Initialize a virtual scheduler to process the graph. Make sure to use
-  // static shape inference to prevent the scheduler from calling the Run
-  // method on the cluster and creating an infinite loop.
+  // Initializes an analytical cost estimator to estimate the graph cost. Makes
+  // sure to use static shape inference to prevent the virtual scheduler from
+  // calling the Run method on the cluster and creating an infinite loop.
   GrapplerItem item;
   item.graph = graph;
   item.feed = feed;
   item.fetch = fetch;
-  // Note that we do not use aggressive shape inference to preserve unknown
-  // shapes from the input graph.
-  VirtualScheduler scheduler(/*use_static_shapes=*/true,
-                             /*use_aggressive_shape_inference=*/false, this,
-                             node_manager_.get());
-  TF_RETURN_IF_ERROR(scheduler.Init(&item));
 
   if (metadata) {
     metadata->clear_step_stats();
@@ -86,45 +81,14 @@ Status VirtualCluster::Run(const GraphDef& graph,
     metadata->clear_partition_graphs();
   }
 
-  Costs node_costs;
-  int node_id = 0;
-  do {
-    OpContext op_context = scheduler.GetCurrNode();
-    node_costs = node_estimator_->PredictCosts(op_context);
-    if (metadata) {
-      CostGraphDef::Node* cost_node =
-          metadata->mutable_cost_graph()->add_node();
-      const string& op_name = op_context.name;
-      cost_node->set_id(node_id++);
-      cost_node->set_name(op_name);
-      cost_node->set_device(op_context.device_name);
-      cost_node->set_compute_cost(
-          node_costs.execution_time.asMicroSeconds().count());
-      cost_node->set_compute_time(
-          node_costs.compute_time.asMicroSeconds().count());
-      cost_node->set_memory_time(
-          node_costs.memory_time.asMicroSeconds().count());
-      for (const auto& output : op_context.op_info.outputs()) {
-        auto output_info = cost_node->add_output_info();
-        output_info->set_dtype(output.dtype());
-        *output_info->mutable_shape() = output.shape();
-
-        int64 size = DataTypeSize(output.dtype());
-        for (const auto& dim : output.shape().dim()) {
-          size *= std::max<int64>(1, dim.size());
-        }
-        output_info->set_size(size);
-      }
-    }
-  } while (scheduler.MarkCurrNodeExecuted(node_costs));
-
-  if (metadata) {
-    scheduler.Summary(metadata);
-  }
+  TF_RETURN_IF_ERROR(estimator_->Initialize(item));
+  Costs ignored_costs;
+  TF_RETURN_IF_ERROR(
+      estimator_->PredictCosts(item.graph, metadata, &ignored_costs));
 
   const std::unordered_map<string, DeviceProperties>& device = GetDevices();
   std::unordered_map<string, int64> peak_mem_usage =
-      scheduler.GetPeakMemoryUsage();
+      estimator_->GetScheduler()->GetPeakMemoryUsage();
   for (const auto& mem_usage : peak_mem_usage) {
     const string& device_name = mem_usage.first;
     auto it = device.find(device_name);
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
 #include "tensorflow/core/grappler/costs/virtual_scheduler.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"
@@ -50,9 +51,8 @@ class VirtualCluster : public Cluster {
   const DeviceSet* GetDeviceSet() const override { return device_set_; }
 
  private:
-  std::unique_ptr<OpLevelCostEstimator> node_estimator_;
-  std::unique_ptr<ReadyNodeManager> node_manager_;
-  const DeviceSet* device_set_ = nullptr;  // Not owned
+  std::unique_ptr<AnalyticalCostEstimator> estimator_;
+  const DeviceSet* device_set_ = nullptr;
 };
 
 }  // end namespace grappler
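The VirtualCluster changes above can be exercised roughly as follows. This sketch is not part of the commit: EstimateWithVirtualCluster is a hypothetical helper, and the Provision/Run calls and GrapplerItem fields are taken from signatures that appear elsewhere in this diff.

#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace grappler {

// Hypothetical driver (not part of this commit): runs a GrapplerItem through
// VirtualCluster, which now delegates cost estimation to
// AnalyticalCostEstimator internally.
Status EstimateWithVirtualCluster(
    const std::unordered_map<string, DeviceProperties>& devices,
    const GrapplerItem& item, RunMetadata* metadata) {
  VirtualCluster cluster(devices);
  TF_RETURN_IF_ERROR(cluster.Provision());
  // Run() populates `metadata` (cost graph and related stats) through the
  // internal estimator.
  return cluster.Run(item.graph, item.feed, item.fetch, metadata);
}

}  // namespace grappler
}  // namespace tensorflow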
@@ -56,7 +56,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
     (*name_to_id)[node->name()] = node->id();
   }
   // For nodes we have seen before (e.g. Merge nodes are executed twice by
-  // VirtualScheduler), the following fields will be overwritten/updated
+  // VirtualScheduler), the following fields will be overwritten/updated.
   node->set_device(op_context.device_name);
   node->set_compute_cost(node_costs.execution_time.asMicroSeconds().count());
   node->set_compute_time(node_costs.compute_time.asMicroSeconds().count());
@@ -67,7 +67,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
     int input_port;
     string input_name = ParseNodeName(input, &input_port);
 
-    // All inputs should have been seen already unless this is a Merge node
+    // All inputs should have been seen already unless this is a Merge node.
     if (name_to_id->find(input_name) == name_to_id->end()) {
       if (!IsMerge(*node_manager->GetCurrNode()))
         LOG(ERROR) << "input: " << input
@@ -76,7 +76,7 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
       // For Merge node, some of inputs may not be seen before
       // For example, for a typical while loop in tensorflow, Merge node
       // will be executed twice by VirtualScheduler (one for Enter, the
-      // other for NextIteration), so eventually both inputs will be added
+      // other for NextIteration), so eventually both inputs will be added.
       continue;
     }
 
@@ -93,30 +93,38 @@ void AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context,
     auto output_info = node->add_output_info();
     output_info->set_alias_input_port(-1);
     output_info->set_dtype(output.dtype());
-    auto shape = output_info->mutable_shape();
-    *shape = output.shape();
+    *output_info->mutable_shape() = output.shape();
+
+    int64 size = DataTypeSize(output.dtype());
+    for (const auto& dim : output.shape().dim()) {
+      size *= std::max<int64>(1, dim.size());
+    }
+    output_info->set_size(size);
   }
 }
 
 }  // namespace
 
-AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
-                                                 bool use_static_shapes)
+AnalyticalCostEstimator::AnalyticalCostEstimator(
+    Cluster* cluster, bool use_static_shapes,
+    bool use_aggressive_shape_inference)
     : AnalyticalCostEstimator(
           cluster, absl::make_unique<OpLevelCostEstimator>(),
-          ReadyNodeManagerFactory("FirstReady"), use_static_shapes) {}
+          ReadyNodeManagerFactory("FirstReady"), use_static_shapes,
+          use_aggressive_shape_inference) {}
 
 AnalyticalCostEstimator::AnalyticalCostEstimator(
     Cluster* cluster, std::unique_ptr<OpLevelCostEstimator> node_estimator,
-    std::unique_ptr<ReadyNodeManager> node_manager, bool use_static_shapes)
+    std::unique_ptr<ReadyNodeManager> node_manager, bool use_static_shapes,
+    bool use_aggressive_shape_inference)
     : cluster_(cluster),
      node_estimator_(std::move(node_estimator)),
      node_manager_(std::move(node_manager)),
-      use_static_shapes_(use_static_shapes) {
-  // Use aggressive static shape inference to minimize unknown shapes.
+      use_static_shapes_(use_static_shapes),
+      use_aggressive_shape_inference_(use_aggressive_shape_inference) {
   scheduler_ = absl::make_unique<VirtualScheduler>(
-      use_static_shapes_,
-      /*use_aggressive_shape_inference=*/true, cluster_, node_manager_.get());
+      use_static_shapes_, use_aggressive_shape_inference_, cluster_,
+      node_manager_.get());
 }
 
 Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
@@ -142,7 +150,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
     cost_graph = run_metadata->mutable_cost_graph();
     // TODO(pcma): Clear nodes in cost_graph after we make sure we always pass
     // in an empty cost_graph (a non-empty but incomplete cost_graph will cause
-    // problems, e.g., no node_id in cost_graph)
+    // problems, e.g., no node_id in cost_graph).
     for (auto& node : *cost_graph->mutable_node()) {
       name_to_cost_node[node.name()] = &node;
     }
@@ -165,7 +173,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
                 << node_costs.num_ops_with_unknown_shapes << " unknown shapes";
     }
 
-    // TODO(pcma): Add unit tests for generating CostGraphDef
+    // TODO(pcma): Add unit tests for generating CostGraphDef.
    if (cost_graph) {
      AddCostNode(node_manager_.get(), op_context, node_id++, node_costs,
                  &name_to_cost_node, &name_to_id, cost_graph);
@@ -35,15 +35,19 @@ struct GrapplerItem;
 
 // Estimate the cost of running a Grappler item based on the theoretical
 // performance of the hardware that will run the model. Note that this
-// internally uses aggressive shape inference with static shape inference.
+// internally uses static shape inference. An option for aggressive shape
+// inference is provided to minimize unknown shapes, and this is only applicable
+// with static shape inference.
 class AnalyticalCostEstimator : public CostEstimator {
  public:
   // Does not take ownership of cluster.
-  AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
+  AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes,
+                          bool use_aggressive_shape_inference);
   AnalyticalCostEstimator(Cluster* cluster,
                           std::unique_ptr<OpLevelCostEstimator> node_estimator,
                           std::unique_ptr<ReadyNodeManager> node_manager,
-                          bool use_static_shapes);
+                          bool use_static_shapes,
+                          bool use_aggressive_shape_inference);
   ~AnalyticalCostEstimator() override {}
 
   // Initializes the estimator for the specified grappler item.
@@ -63,8 +67,10 @@ class AnalyticalCostEstimator : public CostEstimator {
   GrapplerItem item_;
   std::unique_ptr<OpLevelCostEstimator> node_estimator_;
   std::unique_ptr<ReadyNodeManager> node_manager_;
-  bool use_static_shapes_;
   std::unique_ptr<VirtualScheduler> scheduler_;
+
+  bool use_static_shapes_;
+  bool use_aggressive_shape_inference_;
 };
 
 }  // end namespace grappler
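Given the header above, constructing the estimator with an explicit node estimator and ready-node manager looks roughly like this. MakeAggressiveEstimator is a hypothetical helper, not part of the commit; it mirrors what the delegating constructor does internally, but with aggressive shape inference turned on.

#include <memory>

#include "absl/memory/memory.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

namespace tensorflow {
namespace grappler {

// Hypothetical factory (not part of this commit): builds the estimator through
// the second constructor, supplying the node estimator and ready-node manager
// explicitly, with aggressive shape inference enabled to minimize unknown
// shapes.
std::unique_ptr<AnalyticalCostEstimator> MakeAggressiveEstimator(
    Cluster* cluster) {
  return absl::make_unique<AnalyticalCostEstimator>(
      cluster, absl::make_unique<OpLevelCostEstimator>(),
      ReadyNodeManagerFactory("FirstReady"),
      /*use_static_shapes=*/true, /*use_aggressive_shape_inference=*/true);
}

}  // namespace grappler
}  // namespace tensorflow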
@@ -94,7 +94,8 @@ class AnalyticalCostEstimatorTest : public ::testing::Test {
 TEST_F(AnalyticalCostEstimatorTest, SimpleTest) {
   GrapplerItem item = CreateMiniGraph();
 
-  AnalyticalCostEstimator estimator(cluster_.get(), true);
+  AnalyticalCostEstimator estimator(cluster_.get(), /*use_static_shapes=*/true,
+                                    /*use_aggressive_shape_inference=*/true);
   TF_ASSERT_OK(estimator.Initialize(item));
 
   RunMetadata run_metadata;
@@ -153,7 +153,7 @@ static GCluster TF_NewVirtualCluster(
   for (const auto& named_device : named_devices) {
     devices[named_device.name()] = named_device.properties();
   }
-  tensorflow::grappler::Cluster*cluster_ =
+  tensorflow::grappler::Cluster* cluster_ =
      new tensorflow::grappler::VirtualCluster(devices);
  PyGILState_STATE gstate = PyGILState_Ensure();
  tensorflow::Status status = cluster_->Provision();
@@ -99,9 +99,7 @@ class ClusterTest(test.TestCase):
         type='GPU',
         frequency=1000,
         num_cores=60,
-        environment={
-            'architecture': '7'
-        })
+        environment={'architecture': '7'})
     named_device = device_properties_pb2.NamedDevice(
         properties=device_properties, name='/device:GPU:0')
     grappler_cluster = cluster.Cluster(
@@ -27,7 +27,8 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
                            const string& suffix)
     : item_(&item),
       measure_estimator_(cluster, 10, 0),
-      analytical_estimator_(cluster, false),
+      analytical_estimator_(cluster, /*use_static_shapes=*/false,
+                            /*use_aggressive_shape_inference=*/true),
       suffix_(suffix) {}
 
 Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report,
@@ -125,7 +126,6 @@ void CostAnalyzer::PreprocessCosts() {
   }
 }
 
-
 void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
   for (const auto& op : ops) {
     ops_.push_back(op.second);