Merge pull request #44794 from xinan-jiang:pr/cost-model
PiperOrigin-RevId: 355879939 Change-Id: If403700f53e68edf3ff3966bf245218aba36e46a
This commit is contained in:
commit
45f36e8699
@ -1906,6 +1906,7 @@ tf_cc_tests(
|
||||
"//tensorflow/core/example:feature_util_test.cc",
|
||||
"//tensorflow/core/graph:algorithm_test.cc",
|
||||
"//tensorflow/core/graph:control_flow_test.cc",
|
||||
"//tensorflow/core/graph:costmodel_test.cc",
|
||||
"//tensorflow/core/graph:edgeset_test.cc",
|
||||
"//tensorflow/core/graph:graph_def_builder_test.cc",
|
||||
"//tensorflow/core/graph:graph_partition_test.cc",
|
||||
|
@ -201,6 +201,7 @@ exports_files(
|
||||
srcs = [
|
||||
"algorithm_test.cc",
|
||||
"control_flow_test.cc",
|
||||
"costmodel_test.cc",
|
||||
"edgeset_test.cc",
|
||||
"graph_def_builder_test.cc",
|
||||
"graph_partition_test.cc",
|
||||
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
@ -479,11 +481,12 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
|
||||
CostGraphDef* cost_graph) const {
|
||||
std::vector<const Edge*> inputs;
|
||||
std::vector<const Edge*> control_inputs;
|
||||
int offset = cost_graph->node_size();
|
||||
for (const Node* n : graph->nodes()) {
|
||||
CostGraphDef::Node* cnode = cost_graph->add_node();
|
||||
cnode->set_name(n->name());
|
||||
cnode->set_device(n->assigned_device_name());
|
||||
cnode->set_id(Id(n));
|
||||
cnode->set_id(GlobalId(n, offset));
|
||||
|
||||
inputs.clear();
|
||||
inputs.resize(n->num_inputs(), nullptr);
|
||||
@ -502,7 +505,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
|
||||
|
||||
for (const Edge* e : inputs) {
|
||||
CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
|
||||
input_info->set_preceding_node(Id(e->src()));
|
||||
input_info->set_preceding_node(GlobalId(e->src(), offset));
|
||||
input_info->set_preceding_port(e->src_output());
|
||||
}
|
||||
|
||||
@ -528,7 +531,7 @@ void CostModel::AddToCostGraphDef(const Graph* graph,
|
||||
}
|
||||
|
||||
for (const Edge* e : control_inputs) {
|
||||
cnode->add_control_input(Id(e->src()));
|
||||
cnode->add_control_input(GlobalId(e->src(), offset));
|
||||
}
|
||||
|
||||
cnode->set_temporary_memory_size(TempMemorySize(n).value());
|
||||
|
@ -66,6 +66,14 @@ class CostModel {
|
||||
}
|
||||
}
|
||||
|
||||
inline int GlobalId(const Node* n, int offset) const {
|
||||
if (is_global_) {
|
||||
return n->cost_id();
|
||||
} else {
|
||||
return n->id() + offset;
|
||||
}
|
||||
}
|
||||
|
||||
// Initializes cost model for 'g'.
|
||||
void InitFromGraph(const Graph& g);
|
||||
|
||||
|
104
tensorflow/core/graph/costmodel_test.cc
Normal file
104
tensorflow/core/graph/costmodel_test.cc
Normal file
@ -0,0 +1,104 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/graph/costmodel.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
static void InitGraph(const string& s, Graph* graph) {
|
||||
GraphDef graph_def;
|
||||
|
||||
auto parser = protobuf::TextFormat::Parser();
|
||||
CHECK(parser.MergeFromString(s, &graph_def)) << s;
|
||||
GraphConstructorOptions opts;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
|
||||
}
|
||||
|
||||
static void GenerateStepStats(Graph* graph, StepStats* step_stats,
|
||||
const string& device_name) {
|
||||
// Fill RunMetadata's step_stats and partition_graphs fields.
|
||||
DeviceStepStats* device_stepstats = step_stats->add_dev_stats();
|
||||
device_stepstats->set_device(device_name);
|
||||
for (const auto& node_def : graph->nodes()) {
|
||||
NodeExecStats* node_stats = device_stepstats->add_node_stats();
|
||||
node_stats->set_node_name(node_def->name());
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_OP("Input").Output("o: float").SetIsStateful();
|
||||
|
||||
TEST(CostModelTest, GlobalId) {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
std::unique_ptr<Graph> graph1 =
|
||||
std::unique_ptr<Graph>(new Graph(OpRegistry::Global()));
|
||||
std::unique_ptr<Graph> graph2 =
|
||||
std::unique_ptr<Graph>(new Graph(OpRegistry::Global()));
|
||||
InitGraph(
|
||||
"node { name: 'A1' op: 'Input'}"
|
||||
"node { name: 'B1' op: 'Input'}"
|
||||
"node { name: 'C1' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A1', 'B1'] }"
|
||||
"node { name: 'D1' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A1', 'B1'] }",
|
||||
graph1.get());
|
||||
InitGraph(
|
||||
"node { name: 'A2' op: 'Input'}"
|
||||
"node { name: 'B2' op: 'Input'}"
|
||||
"node { name: 'C2' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A2', 'B2'] }"
|
||||
"node { name: 'D2' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A2', 'B2'] }",
|
||||
graph2.get());
|
||||
StepStats step_stats;
|
||||
GenerateStepStats(graph1.get(), &step_stats, "DummyDevice1");
|
||||
GenerateStepStats(graph2.get(), &step_stats, "DummyDevice2");
|
||||
StepStatsCollector collector(&step_stats);
|
||||
std::unordered_map<string, const Graph*> device_map;
|
||||
device_map["DummyDevice1"] = graph1.get();
|
||||
device_map["DummyDevice2"] = graph2.get();
|
||||
CostModelManager cost_model_manager;
|
||||
collector.BuildCostModel(&cost_model_manager, device_map);
|
||||
CostGraphDef cost_graph_def;
|
||||
TF_ASSERT_OK(
|
||||
cost_model_manager.AddToCostGraphDef(graph1.get(), &cost_graph_def));
|
||||
TF_ASSERT_OK(
|
||||
cost_model_manager.AddToCostGraphDef(graph2.get(), &cost_graph_def));
|
||||
ASSERT_EQ(cost_graph_def.node_size(), 12);
|
||||
absl::flat_hash_map<int32, const CostGraphDef::Node> ids;
|
||||
for (auto node : cost_graph_def.node()) {
|
||||
int32 index = node.id();
|
||||
auto result = ids.insert({index, node});
|
||||
EXPECT_TRUE(result.second);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -18,6 +18,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/lib/strings/proto_serialization.h"
|
||||
@ -209,6 +212,25 @@ string DumpGraphDefToFile(const string& name, GraphDef const& graph_def,
|
||||
return filepath;
|
||||
}
|
||||
|
||||
string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def,
|
||||
const string& dirname) {
|
||||
string filepath;
|
||||
std::unique_ptr<WritableFile> file;
|
||||
Status status = CreateWritableFile(Env::Default(), dirname, name, ".pbtxt",
|
||||
&filepath, &file);
|
||||
if (!status.ok()) {
|
||||
return StrCat("(failed to create writable file: ", status.ToString(), ")");
|
||||
}
|
||||
|
||||
status = WriteTextProtoToUniqueFile(graph_def, file.get());
|
||||
if (!status.ok()) {
|
||||
return StrCat("(failed to dump Graph to '", filepath,
|
||||
"': ", status.ToString(), ")");
|
||||
}
|
||||
LOG(INFO) << "Dumped Graph to " << filepath;
|
||||
return filepath;
|
||||
}
|
||||
|
||||
string DumpGraphToFile(const string& name, Graph const& graph,
|
||||
const FunctionLibraryDefinition* flib_def,
|
||||
const string& dirname) {
|
||||
|
@ -19,6 +19,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_
|
||||
#define TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -39,6 +42,10 @@ namespace tensorflow {
|
||||
string DumpGraphDefToFile(const string& name, GraphDef const& graph_def,
|
||||
const string& dirname = "");
|
||||
|
||||
// Similar to DumpGraphDefToFile, use CostGraphDef instead of GraphDef.
|
||||
string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def,
|
||||
const string& dirname = "");
|
||||
|
||||
// Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph'
|
||||
// and an optional function library 'flib_def'. Returns the file name chosen.
|
||||
string DumpGraphToFile(const string& name, Graph const& graph,
|
||||
|
Loading…
Reference in New Issue
Block a user