Add PinToHostOptimizer to grappler: force small ops to happen on CPU (instead of

GPU). This avoids many unnecessary CPU<->GPU memcpy and syncs.

PiperOrigin-RevId: 214108484
This commit is contained in:
A. Unique TensorFlower 2018-09-22 05:15:18 -07:00 committed by TensorFlower Gardener
parent e317152dad
commit ca552d54ac
15 changed files with 599 additions and 5 deletions

View File

@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_min_graph_nodes(-1);
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_pin_to_host_optimization(RewriterConfig::OFF);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;

View File

@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->set_shape_optimization(RewriterConfig::OFF);
rewriter_config->set_remapping(RewriterConfig::OFF);
rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {

View File

@ -14,11 +14,41 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
++output_arg_id) {
if (port_id < 0) {
return -1;
} else if (port_id == 0) {
return output_arg_id;
}
const auto& output_arg = op.output_arg(output_arg_id);
if (!output_arg.number_attr().empty()) {
const int n = node.attr().at(output_arg.number_attr()).i();
if (n < 0) {
// This should never happen.
DCHECK_GE(n, 0);
return -1;
}
if (port_id < n) {
return output_arg_id;
}
port_id -= n;
} else {
--port_id;
}
}
return -1;
}
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);

View File

@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
// Map a node/op's output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
// of tensors.
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:

View File

@ -25,6 +25,60 @@ namespace {
class GraphViewTest : public ::testing::Test {};
TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
ops::ShapeN b(s.WithOpName("b"), {a, a, a});
GraphDef graph_def;
TF_CHECK_OK(s.ToGraphDef(&graph_def));
GraphView graph_view(&graph_def);
const NodeDef& a_node_def = *graph_view.GetNode("a");
const NodeDef& b_node_def = *graph_view.GetNode("b");
const OpDef* a_op_def = nullptr;
const OpDef* b_op_def = nullptr;
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
}
TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
for (int num_splits : {1, 2}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
GraphDef graph_def;
TF_CHECK_OK(s.ToGraphDef(&graph_def));
GraphView graph_view(&graph_def);
const NodeDef& b_node_def = *graph_view.GetNode("b");
const OpDef* b_op_def = nullptr;
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
int arg_id = -1;
if (port_id < num_splits * 3) {
arg_id = port_id / num_splits;
}
EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
}
}
}
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;

View File

@ -518,6 +518,7 @@ cc_library(
":loop_optimizer",
":memory_optimizer",
":model_pruner",
":pin_to_host_optimizer",
":remapper",
":scoped_allocator_optimizer",
":shape_optimizer",
@ -883,3 +884,41 @@ tf_cc_test(
"//tensorflow/core/grappler/utils:grappler_test",
],
)
cc_library(
name = "pin_to_host_optimizer",
srcs = ["pin_to_host_optimizer.cc"],
hdrs = [
"pin_to_host_optimizer.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
"//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
tf_cuda_cc_test(
name = "pin_to_host_optimizer_test",
srcs = ["pin_to_host_optimizer_test.cc"],
deps = [
":pin_to_host_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:grappler_test",
],
)

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@ -105,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@ -133,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@ -468,6 +473,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
cfg.pin_to_host_optimization() == RewriterConfig::ON ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}

View File

@ -0,0 +1,218 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
namespace internal {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
// Find KernelDef for `node`.
Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
// Try find KernelDef for node.device, else GPU or CPU.
for (const DeviceType& device :
{node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
Status s = FindKernelDef(device, node, kdef, nullptr);
if (s.ok()) {
return Status::OK();
}
}
return errors::NotFound("Could not find KernelDef for op: ", node.op());
}
// Check if all node's inputs are pinned to CPU memory.
bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
// Loop through all the inputs excluding the controlling nodes.
for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
// Check if (the fanin) op's device is on CPU.
if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
continue;
}
// Check if (the fanin) op's output port is pinned to HostMemory.
const OpDef* fanin_odef = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
if (!s.ok()) {
LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
return false;
}
const int output_arg_id =
OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
if (output_arg_id < 0) {
LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
<< node.DebugString() << "\n"
<< fanin_odef->DebugString();
return false;
}
const KernelDef* fanin_kdef = nullptr;
s = TryFindKernelDef(*fanin.node, &fanin_kdef);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
return false;
}
bool fanin_pinned = false;
for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
fanin_pinned = true;
break;
}
}
if (!fanin_pinned) {
return false;
}
}
return true;
}
bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
// Check if Tensor is integer and small size.
// Check type to be int32 or int64.
if (prop.dtype() != DataType::DT_INT32 &&
prop.dtype() != DataType::DT_INT64) {
return false;
}
// Check size known and small.
const int64 size = NumCoefficients(prop.shape());
if (size < 0 || size > kTensorMaxSize) {
return false;
}
return true;
}
bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
const NodeDef& node) {
for (const auto& prop : properties.GetInputProperties(node.name())) {
if (!IsTensorIntegerAndSmall(prop)) {
return false;
}
}
for (const auto& prop : properties.GetOutputProperties(node.name())) {
if (!IsTensorIntegerAndSmall(prop)) {
return false;
}
}
return true;
}
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
bool has_device_cpu, const string& device) {
// Force this node onto the CPU.
if (device.empty() && has_device_cpu) {
return "/device:CPU:0";
} else if (str_util::StrContains(device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
for (const auto& device_match :
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
strings::StrCat(device.substr(0, device.rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
return device_host;
}
}
}
// We couldn't find an appropriate Host device, return original device.
return device;
}
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;
GraphProperties properties(item);
bool has_properties = false;
GraphView graph(optimized_graph);
gtl::FlatSet<string> devices;
if (cluster) {
const std::vector<string> device_names = cluster->GetDeviceNames();
devices.insert(device_names.begin(), device_names.end());
} else {
devices = {"/device:CPU:0"};
}
const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
// Topologically sort the graph, so that we traverse the nodes in order. This
// will help us discover producer->consumer chains of Host ops.
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
for (auto& node : *optimized_graph->mutable_node()) {
// Check if node already on CPU.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
continue;
}
// Check the node can be run on CPU.
Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
if (!s.ok()) {
continue;
}
// Check all input's are pinned to CPU.
if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
continue;
}
if (!has_properties) {
// This is an expensive call, call it lazily.
TF_RETURN_IF_ERROR(properties.InferStatically(false));
has_properties = true;
}
// Check all inputs and outputs are integers and small.
if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
continue;
}
// Try and swap the device to Host.
node.set_device(
internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
return Status::OK();
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,62 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
#include <unordered_set>
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
bool has_device_cpu, const string& device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
// excessive cpu<->gpu memcpy/sync.
//
// TODO(williamchan): The current heuristic will swap any small integer Const to
// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of
// gpu->gpu->gpu may have been better/faster. We should probably fix this.
class PinToHostOptimizer : public GraphOptimizer {
public:
PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
: opt_level_(opt_level) {}
~PinToHostOptimizer() override {}
string name() const override { return "pin_to_host_optimizer"; };
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override {}
private:
RewriterConfig::Toggle opt_level_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_

View File

@ -0,0 +1,162 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
gtl::FlatSet<string> devices = {};
EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
"/device:CPU:0");
EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
"/device:CPU:0");
devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
"/device:XLA_CPU:0");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
"/device:XLA_CPU:0");
devices = {"/device:XLA_GPU:0"};
EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
"/device:XLA_GPU:0");
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
"/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
Output c = ops::Shape(s.WithOpName("c"), a);
Output d = ops::Const(s.WithOpName("d"), 0, {1});
Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
GrapplerItem item;
item.fetch = {"a", "c", "d", "e"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
GraphDef output;
PinToHostOptimizer optimizer(RewriterConfig::ON);
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
auto tensors = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(tensors_expected.size(), tensors.size());
for (int i = 0; i < tensors.size(); ++i) {
test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
}
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "a" || node.name() == "c") {
EXPECT_TRUE(node.device().empty());
} else if (node.name() == "d" || node.name() == "e") {
EXPECT_EQ(node.device(), "/device:CPU:0");
}
++found;
}
EXPECT_EQ(found, 4);
}
TEST_F(PinToHostOptimizerTest, TopologicalSort) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
Output c = ops::Shape(s.WithOpName("c"), a);
Output d = ops::Const(s.WithOpName("d"), 0, {1});
Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
GrapplerItem item;
item.fetch = {"a", "c", "d", "e"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
// Reverse the graph, and hence rely on the optimizer to sort it.
std::reverse(item.graph.mutable_node()->begin(),
item.graph.mutable_node()->end());
GraphDef output;
PinToHostOptimizer optimizer(RewriterConfig::ON);
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
auto tensors = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(tensors_expected.size(), tensors.size());
for (int i = 0; i < tensors.size(); ++i) {
test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
}
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "a" || node.name() == "c") {
EXPECT_TRUE(node.device().empty());
} else if (node.name() == "d" || node.name() == "e") {
EXPECT_EQ(node.device(), "/device:CPU:0");
}
++found;
}
EXPECT_EQ(found, 4);
}
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
ops::ShapeN b(s.WithOpName("b"), {a, a, a});
GrapplerItem item;
item.fetch = {"a", "b"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
GraphDef output;
PinToHostOptimizer optimizer(RewriterConfig::ON);
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
auto tensors = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(tensors_expected.size(), tensors.size());
for (int i = 0; i < tensors.size(); ++i) {
test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
}
int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_EQ(node.device(), "/device:CPU:0");
++found;
}
EXPECT_EQ(found, 2);
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
// optimizations interfering in the comparison.
RewriterConfig* cfg =
options_.config.mutable_graph_options()->mutable_rewrite_options();
cfg->set_constant_folding(RewriterConfig::OFF);
// TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
// off.
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
cfg->set_constant_folding(RewriterConfig::OFF);
cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_dependency_optimization(RewriterConfig::OFF);
cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(

View File

@ -75,6 +75,8 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
// Force small ops onto the CPU (default is OFF).
Toggle pin_to_host_optimization = 18;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).

View File

@ -57,7 +57,8 @@ def no_rewrite_session_config():
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)

View File

@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
min_graph_nodes=-1)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
sess, cond, expected_output=21.0)
def testReconstructGraphWithWhileLoop(self):
with session.Session() as sess:
with session.Session(config=self._no_rewrite_session_config()) as sess:
loop_body = lambda i: math_ops.add(i, 2)
loop_cond = lambda i: math_ops.less(i, 16)
i = constant_op.constant(10, name="i")

View File

@ -1934,6 +1934,8 @@ class TensorFlowTestCase(googletest.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
config.graph_options.rewrite_options.arithmetic_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
config.graph_options.rewrite_options.pin_to_host_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
return config
return ErrorLoggingSession(graph=graph, config=prepare_config(config))