Fixes of AutoParallel bug (#10368)
* Fix the bug that auto_parallel could replicate variable snapshot name
* Use NodeName in grappler:utils instead of substr, convert variables->variable_def of grappler item
* Remove variable_def from grappler item, exclude snapshot nodes from dont_replicate_nodes in auto_parallel
This commit is contained in:
parent
b7acb6abe0
commit
b58d983533
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/variable.pb.h"
|
||||||
#include "tensorflow/core/protobuf/queue_runner.pb.h"
|
#include "tensorflow/core/protobuf/queue_runner.pb.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -167,6 +167,11 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
|
|||||||
for (const auto& variable : item.MainVariables()) {
|
for (const auto& variable : item.MainVariables()) {
|
||||||
dont_replicate_nodes.insert(variable->name());
|
dont_replicate_nodes.insert(variable->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const auto& init : item.init_ops) {
|
||||||
|
dont_replicate_nodes.insert(NodeName(init));
|
||||||
|
}
|
||||||
|
|
||||||
// Don't replicate all input nodes, except the dequeue node.
|
// Don't replicate all input nodes, except the dequeue node.
|
||||||
for (const auto& input_node : input_nodes) {
|
for (const auto& input_node : input_nodes) {
|
||||||
if (input_node->name() != dequeue_node->name()) {
|
if (input_node->name() != dequeue_node->name()) {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
|
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||||
|
#include "tensorflow/core/framework/variable.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -33,6 +33,7 @@ TEST_F(AutoParallelTest, SimpleParallel) {
|
|||||||
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
|
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
|
||||||
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
|
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
|
||||||
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
|
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
|
||||||
|
Output identity = ops::Identity(s.WithOpName("identity"), {var});
|
||||||
Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
|
Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
|
||||||
auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
|
auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
|
||||||
{constant_b}, {DT_FLOAT});
|
{constant_b}, {DT_FLOAT});
|
||||||
@ -44,13 +45,14 @@ TEST_F(AutoParallelTest, SimpleParallel) {
|
|||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.init_ops.push_back("assign");
|
item.init_ops.push_back("assign");
|
||||||
item.fetch.push_back("apply_gradient");
|
item.fetch.push_back("apply_gradient");
|
||||||
|
item.init_ops.push_back("assign");
|
||||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
AutoParallel parallel(2);
|
AutoParallel parallel(2);
|
||||||
GraphDef output;
|
GraphDef output;
|
||||||
Status status = parallel.Optimize(nullptr, item, &output);
|
Status status = parallel.Optimize(nullptr, item, &output);
|
||||||
TF_EXPECT_OK(status);
|
TF_EXPECT_OK(status);
|
||||||
EXPECT_EQ(20, output.node_size());
|
EXPECT_EQ(21, output.node_size());
|
||||||
|
|
||||||
const NodeDef& node_assign = output.node(0);
|
const NodeDef& node_assign = output.node(0);
|
||||||
EXPECT_EQ("assign", node_assign.name());
|
EXPECT_EQ("assign", node_assign.name());
|
||||||
@ -62,60 +64,64 @@ TEST_F(AutoParallelTest, SimpleParallel) {
|
|||||||
const NodeDef& node_fifo_queue = output.node(2);
|
const NodeDef& node_fifo_queue = output.node(2);
|
||||||
EXPECT_EQ("fifo_queue", node_fifo_queue.name());
|
EXPECT_EQ("fifo_queue", node_fifo_queue.name());
|
||||||
|
|
||||||
const NodeDef& node_var = output.node(3);
|
const NodeDef& node_identity = output.node(3);
|
||||||
|
EXPECT_EQ("identity", node_identity.name());
|
||||||
|
EXPECT_EQ("var", node_identity.input(0));
|
||||||
|
|
||||||
|
const NodeDef& node_var = output.node(4);
|
||||||
EXPECT_EQ("var", node_var.name());
|
EXPECT_EQ("var", node_var.name());
|
||||||
|
|
||||||
const NodeDef& node_div_const0 = output.node(4);
|
const NodeDef& node_div_const0 = output.node(5);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const",
|
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const",
|
||||||
node_div_const0.name());
|
node_div_const0.name());
|
||||||
|
|
||||||
const NodeDef& node_div0 = output.node(5);
|
const NodeDef& node_div0 = output.node(6);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient",
|
EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient",
|
||||||
node_div0.name());
|
node_div0.name());
|
||||||
const NodeDef& node_add0 = output.node(6);
|
const NodeDef& node_add0 = output.node(7);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name());
|
EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name());
|
||||||
|
|
||||||
const NodeDef& node_gradient0 = output.node(7);
|
const NodeDef& node_gradient0 = output.node(8);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name());
|
EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name());
|
||||||
|
|
||||||
const NodeDef& node_constant_a0 = output.node(8);
|
const NodeDef& node_constant_a0 = output.node(9);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name());
|
EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name());
|
||||||
|
|
||||||
const NodeDef& node_dequeue0 = output.node(9);
|
const NodeDef& node_dequeue0 = output.node(10);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name());
|
EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name());
|
||||||
|
|
||||||
const NodeDef& node_learning_rate0 = output.node(10);
|
const NodeDef& node_learning_rate0 = output.node(11);
|
||||||
EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name());
|
EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name());
|
||||||
|
|
||||||
const NodeDef& node_div_const1 = output.node(11);
|
const NodeDef& node_div_const1 = output.node(12);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const",
|
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const",
|
||||||
node_div_const1.name());
|
node_div_const1.name());
|
||||||
|
|
||||||
const NodeDef& node_div1 = output.node(12);
|
const NodeDef& node_div1 = output.node(13);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient",
|
EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient",
|
||||||
node_div1.name());
|
node_div1.name());
|
||||||
|
|
||||||
const NodeDef& node_add1 = output.node(13);
|
const NodeDef& node_add1 = output.node(14);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name());
|
EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name());
|
||||||
|
|
||||||
const NodeDef& node_gradient1 = output.node(14);
|
const NodeDef& node_gradient1 = output.node(15);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name());
|
EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name());
|
||||||
|
|
||||||
const NodeDef& node_constant_a1 = output.node(15);
|
const NodeDef& node_constant_a1 = output.node(16);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name());
|
EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name());
|
||||||
|
|
||||||
const NodeDef& node_dequeue1 = output.node(16);
|
const NodeDef& node_dequeue1 = output.node(17);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name());
|
EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name());
|
||||||
|
|
||||||
const NodeDef& node_learning_rate1 = output.node(17);
|
const NodeDef& node_learning_rate1 = output.node(18);
|
||||||
EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name());
|
EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name());
|
||||||
|
|
||||||
const NodeDef& node_fetch = output.node(18);
|
const NodeDef& node_fetch = output.node(19);
|
||||||
EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
|
EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
|
||||||
EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0));
|
EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0));
|
||||||
EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1));
|
EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1));
|
||||||
|
|
||||||
const NodeDef& node_gradient = output.node(19);
|
const NodeDef& node_gradient = output.node(20);
|
||||||
EXPECT_EQ("apply_gradient", node_gradient.name());
|
EXPECT_EQ("apply_gradient", node_gradient.name());
|
||||||
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
|
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user