Add PinToHostOptimizer to grappler: force small ops to run on the CPU (instead of the GPU). This avoids many unnecessary CPU<->GPU memcpys and syncs.

PiperOrigin-RevId: 214108484
parent e317152dad
commit ca552d54ac
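For context, a minimal sketch (not part of this commit) of opting in to the new pass from a C++ client. The set_pin_to_host_optimization field is the one this commit adds to RewriterConfig, and the SessionOptions plumbing mirrors the updated DirectSession test below; the helper name and its parameter are hypothetical.

#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session_options.h"

// Builds SessionOptions with the pin-to-host pass explicitly toggled.
tensorflow::SessionOptions MakeSessionOptions(bool pin_small_ops_to_host) {
  tensorflow::SessionOptions options;
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_pin_to_host_optimization(pin_small_ops_to_host
                                         ? tensorflow::RewriterConfig::ON
                                         : tensorflow::RewriterConfig::OFF);
  return options;
}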
@@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_min_graph_nodes(-1);
  options.config.mutable_graph_options()
      ->mutable_rewrite_options()
      ->set_pin_to_host_optimization(RewriterConfig::OFF);
  std::unique_ptr<Session> session(NewSession(options));
  TF_ASSERT_OK(session->Create(def));
  std::vector<std::pair<string, Tensor>> inputs;
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
    rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
    rewriter_config->set_shape_optimization(RewriterConfig::OFF);
    rewriter_config->set_remapping(RewriterConfig::OFF);
    rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
    rewriter_config->mutable_auto_parallel()->set_enable(false);
    rewriter_config->clear_optimizers();
  } else {
@@ -14,11 +14,41 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace grappler {

int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
  for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
       ++output_arg_id) {
    if (port_id < 0) {
      return -1;
    } else if (port_id == 0) {
      return output_arg_id;
    }

    const auto& output_arg = op.output_arg(output_arg_id);
    if (!output_arg.number_attr().empty()) {
      const int n = node.attr().at(output_arg.number_attr()).i();
      if (n < 0) {
        // This should never happen.
        DCHECK_GE(n, 0);
        return -1;
      }
      if (port_id < n) {
        return output_arg_id;
      }
      port_id -= n;
    } else {
      --port_id;
    }
  }

  return -1;
}

GraphView::GraphView(GraphDef* graph) : graph_(graph) {
  for (int i = 0; i < graph_->node_size(); i++) {
    auto node = graph_->mutable_node(i);
@@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace grappler {

// Map a node/op's output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
// of tensors.
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);

// A utility class to simplify the traversal of a GraphDef.
class GraphView {
 public:
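As a worked example of this mapping (illustrative; the op is hypothetical): suppose an op declares output args (scalar s; tensor list l with number_attr "N") and a node sets N = 3. The node then exposes four output tensors, and OpOutputPortIdToArgId resolves:

  port 0            -> arg 0 (s)
  ports 1, 2, 3     -> arg 1 (l)
  port 4 and beyond -> -1 (no corresponding arg)

The ShapeN and SparseSplit tests below exercise exactly this behaviour.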
@@ -25,6 +25,60 @@ namespace {

class GraphViewTest : public ::testing::Test {};

TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  ops::ShapeN b(s.WithOpName("b"), {a, a, a});

  GraphDef graph_def;
  TF_CHECK_OK(s.ToGraphDef(&graph_def));
  GraphView graph_view(&graph_def);

  const NodeDef& a_node_def = *graph_view.GetNode("a");
  const NodeDef& b_node_def = *graph_view.GetNode("b");

  const OpDef* a_op_def = nullptr;
  const OpDef* b_op_def = nullptr;
  EXPECT_TRUE(
      OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
  EXPECT_TRUE(
      OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());

  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));

  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
}

TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
  for (int num_splits : {1, 2}) {
    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
    ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);

    GraphDef graph_def;
    TF_CHECK_OK(s.ToGraphDef(&graph_def));
    GraphView graph_view(&graph_def);

    const NodeDef& b_node_def = *graph_view.GetNode("b");
    const OpDef* b_op_def = nullptr;
    EXPECT_TRUE(
        OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());

    for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
      int arg_id = -1;
      if (port_id < num_splits * 3) {
        arg_id = port_id / num_splits;
      }
      EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
    }
  }
}

TEST_F(GraphViewTest, BasicGraph) {
  TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
  GrapplerItem item;
@@ -518,6 +518,7 @@ cc_library(
        ":loop_optimizer",
        ":memory_optimizer",
        ":model_pruner",
        ":pin_to_host_optimizer",
        ":remapper",
        ":scoped_allocator_optimizer",
        ":shape_optimizer",
@@ -883,3 +884,41 @@ tf_cc_test(
        "//tensorflow/core/grappler/utils:grappler_test",
    ],
)

cc_library(
    name = "pin_to_host_optimizer",
    srcs = ["pin_to_host_optimizer.cc"],
    hdrs = [
        "pin_to_host_optimizer.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":graph_optimizer",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/grappler:graph_view",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler:op_types",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/grappler/costs:graph_properties",
        "//tensorflow/core/grappler/utils:frame",
        "//tensorflow/core/grappler/utils:symbolic_shapes",
        "//tensorflow/core/grappler/utils:topological_sort",
    ],
)

tf_cuda_cc_test(
    name = "pin_to_host_optimizer_test",
    srcs = ["pin_to_host_optimizer_test.cc"],
    deps = [
        ":pin_to_host_optimizer",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/grappler:grappler_item",
        "//tensorflow/core/grappler/utils:grappler_test",
    ],
)
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -105,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
  MK_OPT("scoped_allocator",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  return std::unique_ptr<GraphOptimizer>();
}
@@ -133,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
  if (cfg_.remapping() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
  }
  if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  }
  if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -468,6 +473,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
         cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         cfg.debug_stripper() == RewriterConfig::ON ||
         cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
         cfg.pin_to_host_optimization() == RewriterConfig::ON ||
         !cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc (new file, 218 lines)
@@ -0,0 +1,218 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace grappler {
namespace internal {

// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;

// Find the KernelDef for `node`.
Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
  // Try to find a KernelDef for node.device, else GPU or CPU.
  for (const DeviceType& device :
       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
    Status s = FindKernelDef(device, node, kdef, nullptr);
    if (s.ok()) {
      return Status::OK();
    }
  }

  return errors::NotFound("Could not find KernelDef for op: ", node.op());
}

// Check if all of the node's inputs are pinned to CPU memory.
bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
  // Loop through all the inputs, excluding the controlling nodes.
  for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
    // Check if the fanin op's device is on CPU.
    if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
      continue;
    }

    // Check if the fanin op's output port is pinned to HostMemory.
    const OpDef* fanin_odef = nullptr;
    Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
    if (!s.ok()) {
      LOG(INFO) << "Could not find OpDef for: " << fanin.node->op();
      return false;
    }

    const int output_arg_id =
        OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
    if (output_arg_id < 0) {
      LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
                   << node.DebugString() << "\n"
                   << fanin_odef->DebugString();
      return false;
    }

    const KernelDef* fanin_kdef = nullptr;
    s = TryFindKernelDef(*fanin.node, &fanin_kdef);
    if (!s.ok()) {
      LOG(INFO) << "Could not find KernelDef for: " << fanin.node->op();
      return false;
    }

    bool fanin_pinned = false;
    for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
      if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
        fanin_pinned = true;
        break;
      }
    }

    if (!fanin_pinned) {
      return false;
    }
  }

  return true;
}

// Check if the tensor is an integer of small size.
bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
  // Check that the type is int32 or int64.
  if (prop.dtype() != DataType::DT_INT32 &&
      prop.dtype() != DataType::DT_INT64) {
    return false;
  }

  // Check that the size is known and small.
  const int64 size = NumCoefficients(prop.shape());
  if (size < 0 || size > kTensorMaxSize) {
    return false;
  }

  return true;
}

bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
                                            const NodeDef& node) {
  for (const auto& prop : properties.GetInputProperties(node.name())) {
    if (!IsTensorIntegerAndSmall(prop)) {
      return false;
    }
  }

  for (const auto& prop : properties.GetOutputProperties(node.name())) {
    if (!IsTensorIntegerAndSmall(prop)) {
      return false;
    }
  }
  return true;
}

string TryFindHostDevice(const gtl::FlatSet<string>& devices,
                         bool has_device_cpu, const string& device) {
  // Force this node onto the CPU.
  if (device.empty() && has_device_cpu) {
    return "/device:CPU:0";
  } else if (str_util::StrContains(device, DEVICE_GPU)) {
    // Sometimes the cluster can have:
    //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
    // and we need to handle them properly.
    for (const auto& device_match :
         {std::pair<string, string>("GPU", "CPU:0"),
          std::pair<string, string>("/device", "/device:CPU:0")}) {
      const string device_host =
          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
                          device_match.second);
      if (devices.find(device_host) != devices.end()) {
        return device_host;
      }
    }
  }

  // We couldn't find an appropriate Host device; return the original device.
  return device;
}
}  // end namespace internal

Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  GraphProperties properties(item);
  bool has_properties = false;
  GraphView graph(optimized_graph);

  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
  for (auto& node : *optimized_graph->mutable_node()) {
    // Check if the node is already on the CPU.
    if (str_util::StrContains(node.device(), DEVICE_CPU)) {
      continue;
    }

    // Check that the node can run on the CPU.
    Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
    if (!s.ok()) {
      continue;
    }

    // Check that all inputs are pinned to the CPU.
    if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
      continue;
    }

    if (!has_properties) {
      // This is an expensive call; call it lazily.
      TF_RETURN_IF_ERROR(properties.InferStatically(false));
      has_properties = true;
    }

    // Check that all inputs and outputs are integers and small.
    if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
      continue;
    }

    // Try to swap the device to Host.
    node.set_device(
        internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
  }
  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow
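To make the device-string rewriting in TryFindHostDevice above concrete (a walkthrough, not additional code in this commit): for device = "/device:XLA_GPU:0", the first candidate pair ("GPU", "CPU:0") produces device.substr(0, device.rfind("GPU")) + "CPU:0" = "/device:XLA_CPU:0"; if that name is not in `devices`, the second pair ("/device", "/device:CPU:0") falls back to plain "/device:CPU:0"; if neither is a known device, the original string is returned unchanged. The unit tests further below assert each of these cases.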
tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_

#include <unordered_set>
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {
namespace internal {
// Try to find an appropriate Host device in `devices` given `device`.
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
                         bool has_device_cpu, const string& device);
}  // end namespace internal

// Optimize TensorFlow ops that should be swapped onto the CPU to avoid
// excessive CPU<->GPU memcpy/sync.
//
// TODO(williamchan): The current heuristic will swap any small integer Const to
// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of
// gpu->gpu->gpu may have been better/faster. We should probably fix this.
class PinToHostOptimizer : public GraphOptimizer {
 public:
  PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
  explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
      : opt_level_(opt_level) {}

  ~PinToHostOptimizer() override {}

  string name() const override { return "pin_to_host_optimizer"; }

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* optimized_graph) override;

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& optimized_graph, double result) override {}

 private:
  RewriterConfig::Toggle opt_level_;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
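For reference, a minimal sketch (mirroring the unit tests that follow; the helper name is hypothetical and `item` is assumed to be an already-populated GrapplerItem) of driving the pass directly rather than through the meta optimizer:

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

// Runs the pin-to-host pass standalone; per pin_to_host_optimizer.cc, a null
// Cluster makes the pass assume a single "/device:CPU:0".
tensorflow::Status RunPinToHost(const tensorflow::grappler::GrapplerItem& item,
                                tensorflow::GraphDef* output) {
  tensorflow::grappler::PinToHostOptimizer optimizer(
      tensorflow::RewriterConfig::ON);
  return optimizer.Optimize(/*cluster=*/nullptr, item, output);
}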
tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc (new file, 162 lines)
@@ -0,0 +1,162 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

class PinToHostOptimizerTest : public GrapplerTest {};

TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
  gtl::FlatSet<string> devices = {};
  EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));

  devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
  EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
            "/device:CPU:0");
  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
            "/device:CPU:0");

  devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
            "/device:XLA_CPU:0");
  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
            "/device:XLA_CPU:0");

  devices = {"/device:XLA_GPU:0"};
  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
            "/device:XLA_GPU:0");
  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
            "/device:XLA_GPU:*");
}

TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
  Output c = ops::Shape(s.WithOpName("c"), a);
  Output d = ops::Const(s.WithOpName("d"), 0, {1});
  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);

  GrapplerItem item;
  item.fetch = {"a", "c", "d", "e"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto tensors = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
  }

  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "a" || node.name() == "c") {
      EXPECT_TRUE(node.device().empty());
    } else if (node.name() == "d" || node.name() == "e") {
      EXPECT_EQ(node.device(), "/device:CPU:0");
    }
    ++found;
  }
  EXPECT_EQ(found, 4);
}

TEST_F(PinToHostOptimizerTest, TopologicalSort) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
  Output c = ops::Shape(s.WithOpName("c"), a);
  Output d = ops::Const(s.WithOpName("d"), 0, {1});
  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);

  GrapplerItem item;
  item.fetch = {"a", "c", "d", "e"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  // Reverse the graph, and hence rely on the optimizer to sort it.
  std::reverse(item.graph.mutable_node()->begin(),
               item.graph.mutable_node()->end());

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto tensors = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
  }

  int found = 0;
  for (const NodeDef& node : output.node()) {
    if (node.name() == "a" || node.name() == "c") {
      EXPECT_TRUE(node.device().empty());
    } else if (node.name() == "d" || node.name() == "e") {
      EXPECT_EQ(node.device(), "/device:CPU:0");
    }
    ++found;
  }
  EXPECT_EQ(found, 4);
}

TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
  ops::ShapeN b(s.WithOpName("b"), {a, a, a});

  GrapplerItem item;
  item.fetch = {"a", "b"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);

  GraphDef output;
  PinToHostOptimizer optimizer(RewriterConfig::ON);
  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));

  auto tensors = EvaluateNodes(item.graph, item.fetch);
  EXPECT_EQ(tensors_expected.size(), tensors.size());
  for (int i = 0; i < tensors.size(); ++i) {
    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
  }

  int found = 0;
  for (const NodeDef& node : output.node()) {
    EXPECT_EQ(node.device(), "/device:CPU:0");
    ++found;
  }
  EXPECT_EQ(found, 2);
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
  // optimizations interfering in the comparison.
  RewriterConfig* cfg =
      options_.config.mutable_graph_options()->mutable_rewrite_options();
  cfg->set_constant_folding(RewriterConfig::OFF);
  // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
  // off.
  cfg->set_arithmetic_optimization(RewriterConfig::OFF);
  cfg->set_constant_folding(RewriterConfig::OFF);
  cfg->set_debug_stripper(RewriterConfig::OFF);
  cfg->set_dependency_optimization(RewriterConfig::OFF);
  cfg->set_loop_optimization(RewriterConfig::OFF);
  cfg->set_function_optimization(RewriterConfig::OFF);
  cfg->set_layout_optimizer(RewriterConfig::OFF);
  cfg->set_debug_stripper(RewriterConfig::OFF);
  cfg->set_loop_optimization(RewriterConfig::OFF);
  cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}

std::vector<Tensor> GrapplerTest::EvaluateNodes(
@@ -75,6 +75,8 @@ message RewriterConfig {
  // Try to allocate some independent Op outputs contiguously in order to
  // merge or eliminate downstream Ops (off by default).
  Toggle scoped_allocator_optimization = 15;
  // Force small ops onto the CPU (default is OFF).
  Toggle pin_to_host_optimization = 18;

  // Controls how many times we run the optimizers in meta optimizer (default
  // is once).
@@ -57,7 +57,8 @@ def no_rewrite_session_config():
      disable_model_pruning=True,
      constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)

  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)
@@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
  def _no_rewrite_session_config(self):
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
        min_graph_nodes=-1)
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
    return config_pb2.ConfigProto(graph_options=graph_options)
@@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
        sess, cond, expected_output=21.0)

  def testReconstructGraphWithWhileLoop(self):
    with session.Session() as sess:
    with session.Session(config=self._no_rewrite_session_config()) as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      loop_cond = lambda i: math_ops.less(i, 16)
      i = constant_op.constant(10, name="i")
@@ -1934,6 +1934,8 @@ class TensorFlowTestCase(googletest.TestCase):
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.arithmetic_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.pin_to_host_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    return ErrorLoggingSession(graph=graph, config=prepare_config(config))