Copy function def to the optimized graph in the autoparallel optimizer.

Change: 153896372
This commit is contained in:
Yuefeng Zhou 2017-04-21 17:55:28 -08:00 committed by TensorFlower Gardener
parent e8482ab23b
commit 7bc6271055
4 changed files with 25 additions and 12 deletions

View File

@ -18,6 +18,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsDequeueOp(const NodeDef& node) {
static const std::set<std::string> dequeue_ops = {
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
"QueueDequeue"};
return dequeue_ops.count(node.op()) > 0;
}
bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2";

View File

@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsDequeueOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsVariable(const NodeDef& node);

View File

@ -40,6 +40,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],

View File

@ -14,11 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"
@ -94,22 +97,22 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
VLOG(2) << "Variable: " << var->name();
}
std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
"ApplyProximalGradientDescent",
"ApplyAdadelta",
"ApplyAdagrad",
"ApplyProximalAdagrad",
"ApplyAdagradDA",
"ApplyFtrl",
"ApplyMomentum",
"ApplyAdam",
"ApplyRMSProp",
"ApplyCenteredRMSProp"};
const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
"ApplyProximalGradientDescent",
"ApplyAdadelta",
"ApplyAdagrad",
"ApplyProximalAdagrad",
"ApplyAdagradDA",
"ApplyFtrl",
"ApplyMomentum",
"ApplyAdam",
"ApplyRMSProp",
"ApplyCenteredRMSProp"};
const NodeDef* dequeue_node = nullptr;
for (int i = 0; i < graph_.node_size(); i++) {
all_nodes_.insert(
std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
if (graph_.node(i).op() == "QueueDequeueManyV2") {
if (IsDequeueOp(graph_.node(i))) {
dequeue_node = graph_.mutable_node(i);
}
if (apply_gradients_ops.find(graph_.node(i).op()) !=
@ -241,6 +244,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
for (const auto& fetch : item_->fetch) {
AddNodeControl(fetch, {control->name()}, graph);
}
*(graph->mutable_library()) = item_->graph.library();
LOG(INFO) << "Parallelized graph size: " << graph->node_size();
}