Copy function def to the optimized graph in the autoparallel optimizer.

Change: 153896372
This commit is contained in:
Yuefeng Zhou 2017-04-21 17:55:28 -08:00 committed by TensorFlower Gardener
parent e8482ab23b
commit 7bc6271055
4 changed files with 25 additions and 12 deletions
tensorflow/core/grappler

View File

@ -18,6 +18,13 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
bool IsDequeueOp(const NodeDef& node) {
static const std::set<std::string> dequeue_ops = {
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
"QueueDequeue"};
return dequeue_ops.count(node.op()) > 0;
}
bool IsPlaceholder(const NodeDef& node) { bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op(); const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2"; return op == "Placeholder" || op == "PlaceholderV2";

View File

@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
bool IsDequeueOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node); bool IsPlaceholder(const NodeDef& node);
bool IsVariable(const NodeDef& node); bool IsVariable(const NodeDef& node);

View File

@ -40,6 +40,7 @@ cc_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices", "//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/clusters:cluster",
], ],

View File

@ -14,11 +14,14 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
@ -94,7 +97,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
VLOG(2) << "Variable: " << var->name(); VLOG(2) << "Variable: " << var->name();
} }
std::set<string> apply_gradients_ops = {"ApplyGradientDescent", const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
"ApplyProximalGradientDescent", "ApplyProximalGradientDescent",
"ApplyAdadelta", "ApplyAdadelta",
"ApplyAdagrad", "ApplyAdagrad",
@ -109,7 +112,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
for (int i = 0; i < graph_.node_size(); i++) { for (int i = 0; i < graph_.node_size(); i++) {
all_nodes_.insert( all_nodes_.insert(
std::make_pair(graph_.node(i).name(), graph_.mutable_node(i))); std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
if (graph_.node(i).op() == "QueueDequeueManyV2") { if (IsDequeueOp(graph_.node(i))) {
dequeue_node = graph_.mutable_node(i); dequeue_node = graph_.mutable_node(i);
} }
if (apply_gradients_ops.find(graph_.node(i).op()) != if (apply_gradients_ops.find(graph_.node(i).op()) !=
@ -241,6 +244,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
for (const auto& fetch : item_->fetch) { for (const auto& fetch : item_->fetch) {
AddNodeControl(fetch, {control->name()}, graph); AddNodeControl(fetch, {control->name()}, graph);
} }
*(graph->mutable_library()) = item_->graph.library();
LOG(INFO) << "Parallelized graph size: " << graph->node_size(); LOG(INFO) << "Parallelized graph size: " << graph->node_size();
} }