Copy function def to the optimized graph in the autoparallel optimizer.
Change: 153896372
This commit is contained in:
parent
e8482ab23b
commit
7bc6271055
tensorflow/core/grappler
@ -18,6 +18,13 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
|
bool IsDequeueOp(const NodeDef& node) {
|
||||||
|
static const std::set<std::string> dequeue_ops = {
|
||||||
|
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
|
||||||
|
"QueueDequeue"};
|
||||||
|
return dequeue_ops.count(node.op()) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
bool IsPlaceholder(const NodeDef& node) {
|
bool IsPlaceholder(const NodeDef& node) {
|
||||||
const auto op = node.op();
|
const auto op = node.op();
|
||||||
return op == "Placeholder" || op == "PlaceholderV2";
|
return op == "Placeholder" || op == "PlaceholderV2";
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
|
bool IsDequeueOp(const NodeDef& node);
|
||||||
bool IsPlaceholder(const NodeDef& node);
|
bool IsPlaceholder(const NodeDef& node);
|
||||||
bool IsVariable(const NodeDef& node);
|
bool IsVariable(const NodeDef& node);
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:devices",
|
"//tensorflow/core/grappler:devices",
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/clusters:cluster",
|
"//tensorflow/core/grappler/clusters:cluster",
|
||||||
],
|
],
|
||||||
|
@ -14,11 +14,14 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
|
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
|
||||||
|
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||||
#include "tensorflow/core/grappler/devices.h"
|
#include "tensorflow/core/grappler/devices.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
@ -94,7 +97,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
|
|||||||
VLOG(2) << "Variable: " << var->name();
|
VLOG(2) << "Variable: " << var->name();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
|
const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
|
||||||
"ApplyProximalGradientDescent",
|
"ApplyProximalGradientDescent",
|
||||||
"ApplyAdadelta",
|
"ApplyAdadelta",
|
||||||
"ApplyAdagrad",
|
"ApplyAdagrad",
|
||||||
@ -109,7 +112,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
|
|||||||
for (int i = 0; i < graph_.node_size(); i++) {
|
for (int i = 0; i < graph_.node_size(); i++) {
|
||||||
all_nodes_.insert(
|
all_nodes_.insert(
|
||||||
std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
|
std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
|
||||||
if (graph_.node(i).op() == "QueueDequeueManyV2") {
|
if (IsDequeueOp(graph_.node(i))) {
|
||||||
dequeue_node = graph_.mutable_node(i);
|
dequeue_node = graph_.mutable_node(i);
|
||||||
}
|
}
|
||||||
if (apply_gradients_ops.find(graph_.node(i).op()) !=
|
if (apply_gradients_ops.find(graph_.node(i).op()) !=
|
||||||
@ -241,6 +244,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
|
|||||||
for (const auto& fetch : item_->fetch) {
|
for (const auto& fetch : item_->fetch) {
|
||||||
AddNodeControl(fetch, {control->name()}, graph);
|
AddNodeControl(fetch, {control->name()}, graph);
|
||||||
}
|
}
|
||||||
|
*(graph->mutable_library()) = item_->graph.library();
|
||||||
LOG(INFO) << "Parallelized graph size: " << graph->node_size();
|
LOG(INFO) << "Parallelized graph size: " << graph->node_size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user