Add PinToHostOptimizer to grappler: force small ops to happen on CPU
(instead of GPU). This avoids many unnecessary CPU<->GPU memcpys and syncs.

PiperOrigin-RevId: 214108484

commit ca552d54ac (parent e317152dad)
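
For orientation before the diff: the optimizer is wired into grappler's RewriterConfig as a new pin_to_host_optimization toggle, so a client can opt in or out through the usual session options. Below is a minimal C++ sketch of that wiring (illustrative only, not part of this commit; it simply mirrors the set_pin_to_host_optimization calls added to the tests in the hunks that follow):

    // Sketch: flipping the new toggle on a SessionOptions, mirroring the test changes below.
    #include "tensorflow/core/protobuf/rewriter_config.pb.h"
    #include "tensorflow/core/public/session_options.h"

    void SetPinToHost(tensorflow::SessionOptions* options, bool enable) {
      tensorflow::RewriterConfig* rewriter_config =
          options->config.mutable_graph_options()->mutable_rewrite_options();
      // The field defaults to OFF; the MetaOptimizer only instantiates
      // PinToHostOptimizer when the toggle is explicitly set to ON.
      rewriter_config->set_pin_to_host_optimization(
          enable ? tensorflow::RewriterConfig::ON
                 : tensorflow::RewriterConfig::OFF);
    }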
@@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
   options.config.mutable_graph_options()
       ->mutable_rewrite_options()
       ->set_min_graph_nodes(-1);
+  options.config.mutable_graph_options()
+      ->mutable_rewrite_options()
+      ->set_pin_to_host_optimization(RewriterConfig::OFF);
   std::unique_ptr<Session> session(NewSession(options));
   TF_ASSERT_OK(session->Create(def));
   std::vector<std::pair<string, Tensor>> inputs;
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
     rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
     rewriter_config->set_shape_optimization(RewriterConfig::OFF);
     rewriter_config->set_remapping(RewriterConfig::OFF);
+    rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
     rewriter_config->mutable_auto_parallel()->set_enable(false);
     rewriter_config->clear_optimizers();
   } else {
@@ -14,11 +14,41 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/grappler/utils.h"

 namespace tensorflow {
 namespace grappler {

+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+  for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
+       ++output_arg_id) {
+    if (port_id < 0) {
+      return -1;
+    } else if (port_id == 0) {
+      return output_arg_id;
+    }
+
+    const auto& output_arg = op.output_arg(output_arg_id);
+    if (!output_arg.number_attr().empty()) {
+      const int n = node.attr().at(output_arg.number_attr()).i();
+      if (n < 0) {
+        // This should never happen.
+        DCHECK_GE(n, 0);
+        return -1;
+      }
+      if (port_id < n) {
+        return output_arg_id;
+      }
+      port_id -= n;
+    } else {
+      --port_id;
+    }
+  }
+
+  return -1;
+}
+
 GraphView::GraphView(GraphDef* graph) : graph_(graph) {
   for (int i = 0; i < graph_->node_size(); i++) {
     auto node = graph_->mutable_node(i);
@@ -20,11 +20,21 @@ limitations under the License.
 #include <unordered_set>
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/platform/types.h"

 namespace tensorflow {
 namespace grappler {

+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
 // A utility class to simplify the traversal of a GraphDef.
 class GraphView {
  public:
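
The port-to-arg mapping declared above is easiest to see on ops with list outputs; the arithmetic boils down to walking the output args and subtracting each list length. A stripped-down sketch of that rule follows (illustrative only, written against plain ints rather than NodeDef/OpDef; the real function reads each list length from the op's number_attr):

    // Sketch: the mapping rule behind OpOutputPortIdToArgId. Each entry of
    // output_arg_sizes gives the list length of one output arg (1 for a plain arg).
    #include <vector>

    int PortIdToArgId(const std::vector<int>& output_arg_sizes, int port_id) {
      if (port_id < 0) return -1;  // invalid port
      for (int arg_id = 0; arg_id < static_cast<int>(output_arg_sizes.size());
           ++arg_id) {
        if (port_id < output_arg_sizes[arg_id]) return arg_id;  // port falls inside this arg's list
        port_id -= output_arg_sizes[arg_id];
      }
      return -1;  // ran past the last arg: no corresponding arg_id
    }

So ShapeN with N = 3 (sizes {3}) maps ports 0-2 to arg 0 and port 3 to -1, and SparseSplit with num_split = 2 (sizes {2, 2, 2}) maps ports 0-1, 2-3, 4-5 to args 0, 1, 2, which is exactly what the new unit tests in the next hunk assert.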
@@ -25,6 +25,60 @@ namespace {

 class GraphViewTest : public ::testing::Test {};

+TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+  ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+  GraphDef graph_def;
+  TF_CHECK_OK(s.ToGraphDef(&graph_def));
+  GraphView graph_view(&graph_def);
+
+  const NodeDef& a_node_def = *graph_view.GetNode("a");
+  const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+  const OpDef* a_op_def = nullptr;
+  const OpDef* b_op_def = nullptr;
+  EXPECT_TRUE(
+      OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+  EXPECT_TRUE(
+      OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+  for (int num_splits : {1, 2}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+    ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+    GraphDef graph_def;
+    TF_CHECK_OK(s.ToGraphDef(&graph_def));
+    GraphView graph_view(&graph_def);
+
+    const NodeDef& b_node_def = *graph_view.GetNode("b");
+    const OpDef* b_op_def = nullptr;
+    EXPECT_TRUE(
+        OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+    for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+      int arg_id = -1;
+      if (port_id < num_splits * 3) {
+        arg_id = port_id / num_splits;
+      }
+      EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+    }
+  }
+}
+
 TEST_F(GraphViewTest, BasicGraph) {
   TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
   GrapplerItem item;
@@ -518,6 +518,7 @@ cc_library(
         ":loop_optimizer",
         ":memory_optimizer",
         ":model_pruner",
+        ":pin_to_host_optimizer",
         ":remapper",
         ":scoped_allocator_optimizer",
         ":shape_optimizer",
@@ -883,3 +884,41 @@ tf_cc_test(
         "//tensorflow/core/grappler/utils:grappler_test",
     ],
 )
+
+cc_library(
+    name = "pin_to_host_optimizer",
+    srcs = ["pin_to_host_optimizer.cc"],
+    hdrs = [
+        "pin_to_host_optimizer.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/utils:frame",
+        "//tensorflow/core/grappler/utils:symbolic_shapes",
+        "//tensorflow/core/grappler/utils:topological_sort",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "pin_to_host_optimizer_test",
+    srcs = ["pin_to_host_optimizer_test.cc"],
+    deps = [
+        ":pin_to_host_optimizer",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler/utils:grappler_test",
+    ],
+)
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -105,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
   MK_OPT("scoped_allocator",
          new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                       cfg_.scoped_allocator_opts()));
+  MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

   return std::unique_ptr<GraphOptimizer>();
 }
@@ -133,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
   if (cfg_.remapping() != RewriterConfig::OFF) {
     optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
   }
+  if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+  }
   if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
     optimizers->push_back(
         MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -468,6 +473,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
          cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
          cfg.debug_stripper() == RewriterConfig::ON ||
          cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+         cfg.pin_to_host_optimization() == RewriterConfig::ON ||
          !cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
 }

new file: tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc (218 lines)
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// Find KernelDef for `node`.
+Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
+  // Try find KernelDef for node.device, else GPU or CPU.
+  for (const DeviceType& device :
+       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
+    Status s = FindKernelDef(device, node, kdef, nullptr);
+    if (s.ok()) {
+      return Status::OK();
+    }
+  }
+
+  return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Check if all node's inputs are pinned to CPU memory.
+bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
+  // Loop through all the inputs excluding the controlling nodes.
+  for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
+    // Check if (the fanin) op's device is on CPU.
+    if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
+      continue;
+    }
+
+    // Check if (the fanin) op's output port is pinned to HostMemory.
+    const OpDef* fanin_odef = nullptr;
+    Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
+    if (!s.ok()) {
+      LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
+      return false;
+    }
+
+    const int output_arg_id =
+        OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
+    if (output_arg_id < 0) {
+      LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
+                   << node.DebugString() << "\n"
+                   << fanin_odef->DebugString();
+      return false;
+    }
+
+    const KernelDef* fanin_kdef = nullptr;
+    s = TryFindKernelDef(*fanin.node, &fanin_kdef);
+    if (!s.ok()) {
+      LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
+      return false;
+    }
+
+    bool fanin_pinned = false;
+    for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
+      if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
+        fanin_pinned = true;
+        break;
+      }
+    }
+
+    if (!fanin_pinned) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+  // Check if Tensor is integer and small size.
+
+  // Check type to be int32 or int64.
+  if (prop.dtype() != DataType::DT_INT32 &&
+      prop.dtype() != DataType::DT_INT64) {
+    return false;
+  }
+
+  // Check size known and small.
+  const int64 size = NumCoefficients(prop.shape());
+  if (size < 0 || size > kTensorMaxSize) {
+    return false;
+  }
+
+  return true;
+}
+
+bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
+                                            const NodeDef& node) {
+  for (const auto& prop : properties.GetInputProperties(node.name())) {
+    if (!IsTensorIntegerAndSmall(prop)) {
+      return false;
+    }
+  }
+
+  for (const auto& prop : properties.GetOutputProperties(node.name())) {
+    if (!IsTensorIntegerAndSmall(prop)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+                         bool has_device_cpu, const string& device) {
+  // Force this node onto the CPU.
+  if (device.empty() && has_device_cpu) {
+    return "/device:CPU:0";
+  } else if (str_util::StrContains(device, DEVICE_GPU)) {
+    // Sometimes the cluster can have:
+    //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+    // and we need to handle them properly.
+    for (const auto& device_match :
+         {std::pair<string, string>("GPU", "CPU:0"),
+          std::pair<string, string>("/device", "/device:CPU:0")}) {
+      const string device_host =
+          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+                          device_match.second);
+      if (devices.find(device_host) != devices.end()) {
+        return device_host;
+      }
+    }
+  }
+
+  // We couldn't find an appropriate Host device, return original device.
+  return device;
+}
+}  // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                    GraphDef* optimized_graph) {
+  *optimized_graph = item.graph;
+
+  GraphProperties properties(item);
+  bool has_properties = false;
+  GraphView graph(optimized_graph);
+
+  gtl::FlatSet<string> devices;
+  if (cluster) {
+    const std::vector<string> device_names = cluster->GetDeviceNames();
+    devices.insert(device_names.begin(), device_names.end());
+  } else {
+    devices = {"/device:CPU:0"};
+  }
+
+  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+  // Topologically sort the graph, so that we traverse the nodes in order. This
+  // will help us discover producer->consumer chains of Host ops.
+  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+  for (auto& node : *optimized_graph->mutable_node()) {
+    // Check if node already on CPU.
+    if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+      continue;
+    }
+
+    // Check the node can be run on CPU.
+    Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
+    if (!s.ok()) {
+      continue;
+    }
+
+    // Check all input's are pinned to CPU.
+    if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
+      continue;
+    }
+
+    if (!has_properties) {
+      // This is an expensive call, call it lazily.
+      TF_RETURN_IF_ERROR(properties.InferStatically(false));
+      has_properties = true;
+    }
+
+    // Check all inputs and outputs are integers and small.
+    if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+      continue;
+    }
+
+    // Try and swap the device to Host.
+    node.set_device(
+        internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+  }
+  return Status::OK();
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
new file: tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h (62 lines)
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Try and find an appropriate Host device in `devices` given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+                         bool has_device_cpu, const string& device);
+}  // end namespace internal
+
+// Optimize TensorFlow ops that should be swapped into the CPU to avoid
+// excessive cpu<->gpu memcpy/sync.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of
+// gpu->gpu->gpu may have been better/faster. We should probably fix this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+  PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+  explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+      : opt_level_(opt_level) {}
+
+  ~PinToHostOptimizer() override {}
+
+  string name() const override { return "pin_to_host_optimizer"; };
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override {}
+
+ private:
+  RewriterConfig::Toggle opt_level_;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
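
The header above is the whole public surface: a GraphOptimizer subclass plus the internal::TryFindHostDevice helper exposed for testing. Below is a minimal sketch of driving the pass directly on a GrapplerItem, the same way the unit tests in the next file do (illustrative only; the helper name RunPinToHost and its error handling are not part of this commit):

    // Sketch: run the new optimizer standalone over a GrapplerItem's graph.
    #include "tensorflow/core/grappler/grappler_item.h"
    #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    tensorflow::Status RunPinToHost(tensorflow::grappler::GrapplerItem* item) {
      tensorflow::grappler::PinToHostOptimizer optimizer(
          tensorflow::RewriterConfig::ON);
      tensorflow::GraphDef optimized_graph;
      // With a null Cluster the pass assumes only "/device:CPU:0" exists, so
      // qualifying small integer nodes get their device set to that host device.
      TF_RETURN_IF_ERROR(
          optimizer.Optimize(/*cluster=*/nullptr, *item, &optimized_graph));
      item->graph.Swap(&optimized_graph);
      return tensorflow::Status::OK();
    }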
new file: pin_to_host_optimizer_test.cc (162 lines)
@@ -0,0 +1,162 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+  gtl::FlatSet<string> devices = {};
+  EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+  devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+            "/device:CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+            "/device:CPU:0");
+
+  devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+            "/device:XLA_CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+            "/device:XLA_CPU:0");
+
+  devices = {"/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+            "/device:XLA_GPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+            "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+  Output c = ops::Shape(s.WithOpName("c"), a);
+  Output d = ops::Const(s.WithOpName("d"), 0, {1});
+  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+  GrapplerItem item;
+  item.fetch = {"a", "c", "d", "e"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto tensors = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "a" || node.name() == "c") {
+      EXPECT_TRUE(node.device().empty());
+    } else if (node.name() == "d" || node.name() == "e") {
+      EXPECT_EQ(node.device(), "/device:CPU:0");
+    }
+    ++found;
+  }
+  EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+  Output c = ops::Shape(s.WithOpName("c"), a);
+  Output d = ops::Const(s.WithOpName("d"), 0, {1});
+  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+  GrapplerItem item;
+  item.fetch = {"a", "c", "d", "e"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  // Reverse the graph, and hence rely on the optimizer to sort it.
+  std::reverse(item.graph.mutable_node()->begin(),
+               item.graph.mutable_node()->end());
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto tensors = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "a" || node.name() == "c") {
+      EXPECT_TRUE(node.device().empty());
+    } else if (node.name() == "d" || node.name() == "e") {
+      EXPECT_EQ(node.device(), "/device:CPU:0");
+    }
+    ++found;
+  }
+  EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+  ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+  GrapplerItem item;
+  item.fetch = {"a", "b"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto tensors = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    EXPECT_EQ(node.device(), "/device:CPU:0");
+    ++found;
+  }
+  EXPECT_EQ(found, 2);
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
   // optimizations interfering in the comparison.
   RewriterConfig* cfg =
       options_.config.mutable_graph_options()->mutable_rewrite_options();
-  cfg->set_constant_folding(RewriterConfig::OFF);
+  // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+  // off.
   cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+  cfg->set_constant_folding(RewriterConfig::OFF);
+  cfg->set_debug_stripper(RewriterConfig::OFF);
   cfg->set_dependency_optimization(RewriterConfig::OFF);
-  cfg->set_loop_optimization(RewriterConfig::OFF);
   cfg->set_function_optimization(RewriterConfig::OFF);
   cfg->set_layout_optimizer(RewriterConfig::OFF);
-  cfg->set_debug_stripper(RewriterConfig::OFF);
+  cfg->set_loop_optimization(RewriterConfig::OFF);
+  cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
 }

 std::vector<Tensor> GrapplerTest::EvaluateNodes(
@@ -75,6 +75,8 @@ message RewriterConfig {
   // Try to allocate some independent Op outputs contiguously in order to
   // merge or eliminate downstream Ops (off by default).
   Toggle scoped_allocator_optimization = 15;
+  // Force small ops onto the CPU (default is OFF).
+  Toggle pin_to_host_optimization = 18;

   // Controls how many times we run the optimizers in meta optimizer (default
   // is once).
@@ -57,7 +57,8 @@ def no_rewrite_session_config():
       disable_model_pruning=True,
       constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
       arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
-      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+      pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)

   graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
   return config_pb2.ConfigProto(graph_options=graph_options)
@@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
  def _no_rewrite_session_config(self):
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+        pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        min_graph_nodes=-1)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)
@@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
        sess, cond, expected_output=21.0)

  def testReconstructGraphWithWhileLoop(self):
-    with session.Session() as sess:
+    with session.Session(config=self._no_rewrite_session_config()) as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      loop_cond = lambda i: math_ops.less(i, 16)
      i = constant_op.constant(10, name="i")
@@ -1934,6 +1934,8 @@ class TensorFlowTestCase(googletest.TestCase):
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.arithmetic_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
+      config.graph_options.rewrite_options.pin_to_host_optimization = (
+          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    return ErrorLoggingSession(graph=graph, config=prepare_config(config))