diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index e0709c682b0..1fb54f60935 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variable.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc index d4326a022f4..e3e75a2320c 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -167,6 +167,11 @@ Status AutoParallel::Initialize(const GrapplerItem& item) { for (const auto& variable : item.MainVariables()) { dont_replicate_nodes.insert(variable->name()); } + + for (const auto& init : item.init_ops) { + dont_replicate_nodes.insert(NodeName(init)); + } + // Don't replicate all input nodes, except the dequeue node. for (const auto& input_node : input_nodes) { if (input_node->name() != dequeue_node->name()) { diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h index ad90bbe0289..c5d2d47782f 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.h +++ b/tensorflow/core/grappler/optimizers/auto_parallel.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/framework/variable.pb.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc index 3d1b4a34bfc..9a41b5e0b51 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc @@ -33,6 +33,7 @@ TEST_F(AutoParallelTest, SimpleParallel) { Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a}); + Output identity = ops::Identity(s.WithOpName("identity"), {var}); Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT}); auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue}, {constant_b}, {DT_FLOAT}); @@ -44,13 +45,14 @@ TEST_F(AutoParallelTest, SimpleParallel) { GrapplerItem item; item.init_ops.push_back("assign"); item.fetch.push_back("apply_gradient"); + item.init_ops.push_back("assign"); TF_CHECK_OK(s.ToGraphDef(&item.graph)); AutoParallel parallel(2); GraphDef output; Status status = parallel.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(20, output.node_size()); + EXPECT_EQ(21, output.node_size()); const NodeDef& node_assign = output.node(0); EXPECT_EQ("assign", node_assign.name()); @@ -62,60 +64,64 @@ TEST_F(AutoParallelTest, SimpleParallel) { const NodeDef& node_fifo_queue = output.node(2); EXPECT_EQ("fifo_queue", node_fifo_queue.name()); - const NodeDef& node_var = output.node(3); + const NodeDef& node_identity = output.node(3); + EXPECT_EQ("identity", node_identity.name()); + EXPECT_EQ("var", node_identity.input(0)); + + const NodeDef& node_var = output.node(4); EXPECT_EQ("var", node_var.name()); - const NodeDef& node_div_const0 = output.node(4); + const NodeDef& node_div_const0 = output.node(5); EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const", node_div_const0.name()); - const NodeDef& node_div0 = output.node(5); + const NodeDef& node_div0 = output.node(6); EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient", node_div0.name()); - const NodeDef& node_add0 = output.node(6); + const NodeDef& node_add0 = output.node(7); EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name()); - const NodeDef& node_gradient0 = output.node(7); + const NodeDef& node_gradient0 = output.node(8); EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name()); - const NodeDef& node_constant_a0 = output.node(8); + const NodeDef& node_constant_a0 = output.node(9); EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name()); - const NodeDef& node_dequeue0 = output.node(9); + const NodeDef& node_dequeue0 = output.node(10); EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name()); - const NodeDef& node_learning_rate0 = output.node(10); + const NodeDef& node_learning_rate0 = output.node(11); EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name()); - const NodeDef& node_div_const1 = output.node(11); + const NodeDef& node_div_const1 = output.node(12); EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const", node_div_const1.name()); - const NodeDef& node_div1 = output.node(12); + const NodeDef& node_div1 = output.node(13); EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient", node_div1.name()); - const NodeDef& node_add1 = output.node(13); + const NodeDef& node_add1 = output.node(14); EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name()); - const NodeDef& node_gradient1 = output.node(14); + const NodeDef& node_gradient1 = output.node(15); EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name()); - const NodeDef& node_constant_a1 = output.node(15); + const NodeDef& node_constant_a1 = output.node(16); EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name()); - const NodeDef& node_dequeue1 = output.node(16); + const NodeDef& node_dequeue1 = output.node(17); EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name()); - const NodeDef& node_learning_rate1 = output.node(17); + const NodeDef& node_learning_rate1 = output.node(18); EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name()); - const NodeDef& node_fetch = output.node(18); + const NodeDef& node_fetch = output.node(19); EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name()); EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0)); EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1)); - const NodeDef& node_gradient = output.node(19); + const NodeDef& node_gradient = output.node(20); EXPECT_EQ("apply_gradient", node_gradient.name()); EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0)); }