Copy function def to the optimized graph in the autoparallel optimizer.
Change: 153896372
This commit is contained in:
parent
e8482ab23b
commit
7bc6271055
@ -18,6 +18,13 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
bool IsDequeueOp(const NodeDef& node) {
|
||||
static const std::set<std::string> dequeue_ops = {
|
||||
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
|
||||
"QueueDequeue"};
|
||||
return dequeue_ops.count(node.op()) > 0;
|
||||
}
|
||||
|
||||
bool IsPlaceholder(const NodeDef& node) {
|
||||
const auto op = node.op();
|
||||
return op == "Placeholder" || op == "PlaceholderV2";
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
bool IsDequeueOp(const NodeDef& node);
|
||||
bool IsPlaceholder(const NodeDef& node);
|
||||
bool IsVariable(const NodeDef& node);
|
||||
|
||||
|
@ -40,6 +40,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:devices",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
],
|
||||
|
@ -14,11 +14,14 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
@ -94,22 +97,22 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
|
||||
VLOG(2) << "Variable: " << var->name();
|
||||
}
|
||||
|
||||
std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
|
||||
"ApplyProximalGradientDescent",
|
||||
"ApplyAdadelta",
|
||||
"ApplyAdagrad",
|
||||
"ApplyProximalAdagrad",
|
||||
"ApplyAdagradDA",
|
||||
"ApplyFtrl",
|
||||
"ApplyMomentum",
|
||||
"ApplyAdam",
|
||||
"ApplyRMSProp",
|
||||
"ApplyCenteredRMSProp"};
|
||||
const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
|
||||
"ApplyProximalGradientDescent",
|
||||
"ApplyAdadelta",
|
||||
"ApplyAdagrad",
|
||||
"ApplyProximalAdagrad",
|
||||
"ApplyAdagradDA",
|
||||
"ApplyFtrl",
|
||||
"ApplyMomentum",
|
||||
"ApplyAdam",
|
||||
"ApplyRMSProp",
|
||||
"ApplyCenteredRMSProp"};
|
||||
const NodeDef* dequeue_node = nullptr;
|
||||
for (int i = 0; i < graph_.node_size(); i++) {
|
||||
all_nodes_.insert(
|
||||
std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
|
||||
if (graph_.node(i).op() == "QueueDequeueManyV2") {
|
||||
if (IsDequeueOp(graph_.node(i))) {
|
||||
dequeue_node = graph_.mutable_node(i);
|
||||
}
|
||||
if (apply_gradients_ops.find(graph_.node(i).op()) !=
|
||||
@ -241,6 +244,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
|
||||
for (const auto& fetch : item_->fetch) {
|
||||
AddNodeControl(fetch, {control->name()}, graph);
|
||||
}
|
||||
*(graph->mutable_library()) = item_->graph.library();
|
||||
LOG(INFO) << "Parallelized graph size: " << graph->node_size();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user