Fixes of AutoParallel bug (#10368)

* Fix the bug where auto_parallel could replicate variable snapshot names

* Use NodeName from grappler utils instead of substr; convert variables -> variable_def of grappler item

* Remove variable_def from grappler item; exclude snapshot nodes from dont_replicate_nodes in auto_parallel
This commit is contained in:
sj6077 2017-06-22 04:08:14 +09:00 committed by Benoit Steiner
parent b7acb6abe0
commit b58d983533
4 changed files with 31 additions and 18 deletions

View File

@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h"
namespace tensorflow { namespace tensorflow {

View File

@ -167,6 +167,11 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
for (const auto& variable : item.MainVariables()) { for (const auto& variable : item.MainVariables()) {
dont_replicate_nodes.insert(variable->name()); dont_replicate_nodes.insert(variable->name());
} }
for (const auto& init : item.init_ops) {
dont_replicate_nodes.insert(NodeName(init));
}
// Don't replicate all input nodes, except the dequeue node. // Don't replicate all input nodes, except the dequeue node.
for (const auto& input_node : input_nodes) { for (const auto& input_node : input_nodes) {
if (input_node->name() != dequeue_node->name()) { if (input_node->name() != dequeue_node->name()) {

View File

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
namespace tensorflow { namespace tensorflow {

View File

@ -33,6 +33,7 @@ TEST_F(AutoParallelTest, SimpleParallel) {
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a}); Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
Output identity = ops::Identity(s.WithOpName("identity"), {var});
Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT}); Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue}, auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
{constant_b}, {DT_FLOAT}); {constant_b}, {DT_FLOAT});
@ -44,13 +45,14 @@ TEST_F(AutoParallelTest, SimpleParallel) {
GrapplerItem item; GrapplerItem item;
item.init_ops.push_back("assign"); item.init_ops.push_back("assign");
item.fetch.push_back("apply_gradient"); item.fetch.push_back("apply_gradient");
item.init_ops.push_back("assign");
TF_CHECK_OK(s.ToGraphDef(&item.graph)); TF_CHECK_OK(s.ToGraphDef(&item.graph));
AutoParallel parallel(2); AutoParallel parallel(2);
GraphDef output; GraphDef output;
Status status = parallel.Optimize(nullptr, item, &output); Status status = parallel.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status); TF_EXPECT_OK(status);
EXPECT_EQ(20, output.node_size()); EXPECT_EQ(21, output.node_size());
const NodeDef& node_assign = output.node(0); const NodeDef& node_assign = output.node(0);
EXPECT_EQ("assign", node_assign.name()); EXPECT_EQ("assign", node_assign.name());
@ -62,60 +64,64 @@ TEST_F(AutoParallelTest, SimpleParallel) {
const NodeDef& node_fifo_queue = output.node(2); const NodeDef& node_fifo_queue = output.node(2);
EXPECT_EQ("fifo_queue", node_fifo_queue.name()); EXPECT_EQ("fifo_queue", node_fifo_queue.name());
const NodeDef& node_var = output.node(3); const NodeDef& node_identity = output.node(3);
EXPECT_EQ("identity", node_identity.name());
EXPECT_EQ("var", node_identity.input(0));
const NodeDef& node_var = output.node(4);
EXPECT_EQ("var", node_var.name()); EXPECT_EQ("var", node_var.name());
const NodeDef& node_div_const0 = output.node(4); const NodeDef& node_div_const0 = output.node(5);
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const", EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const",
node_div_const0.name()); node_div_const0.name());
const NodeDef& node_div0 = output.node(5); const NodeDef& node_div0 = output.node(6);
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient", EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient",
node_div0.name()); node_div0.name());
const NodeDef& node_add0 = output.node(6); const NodeDef& node_add0 = output.node(7);
EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name()); EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name());
const NodeDef& node_gradient0 = output.node(7); const NodeDef& node_gradient0 = output.node(8);
EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name()); EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name());
const NodeDef& node_constant_a0 = output.node(8); const NodeDef& node_constant_a0 = output.node(9);
EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name()); EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name());
const NodeDef& node_dequeue0 = output.node(9); const NodeDef& node_dequeue0 = output.node(10);
EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name()); EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name());
const NodeDef& node_learning_rate0 = output.node(10); const NodeDef& node_learning_rate0 = output.node(11);
EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name()); EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name());
const NodeDef& node_div_const1 = output.node(11); const NodeDef& node_div_const1 = output.node(12);
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const", EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const",
node_div_const1.name()); node_div_const1.name());
const NodeDef& node_div1 = output.node(12); const NodeDef& node_div1 = output.node(13);
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient", EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient",
node_div1.name()); node_div1.name());
const NodeDef& node_add1 = output.node(13); const NodeDef& node_add1 = output.node(14);
EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name()); EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name());
const NodeDef& node_gradient1 = output.node(14); const NodeDef& node_gradient1 = output.node(15);
EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name()); EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name());
const NodeDef& node_constant_a1 = output.node(15); const NodeDef& node_constant_a1 = output.node(16);
EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name()); EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name());
const NodeDef& node_dequeue1 = output.node(16); const NodeDef& node_dequeue1 = output.node(17);
EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name()); EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name());
const NodeDef& node_learning_rate1 = output.node(17); const NodeDef& node_learning_rate1 = output.node(18);
EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name()); EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name());
const NodeDef& node_fetch = output.node(18); const NodeDef& node_fetch = output.node(19);
EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name()); EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0)); EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0));
EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1)); EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1));
const NodeDef& node_gradient = output.node(19); const NodeDef& node_gradient = output.node(20);
EXPECT_EQ("apply_gradient", node_gradient.name()); EXPECT_EQ("apply_gradient", node_gradient.name());
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0)); EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
} }