diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index 6a25114e6dc..58f79bd3657 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -211,7 +211,7 @@ NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) { return *this; } -Status NodeDefBuilder::Finalize(NodeDef* node_def) const { +Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { const std::vector<string>* errors_ptr = &errors_; std::vector<string> errors_storage; if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) { @@ -243,7 +243,11 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const { } else { NodeDef node_def_backup; if (node_def == nullptr) node_def = &node_def_backup; - *node_def = node_def_; + if (consume) { + *node_def = std::move(node_def_); + } else { + *node_def = node_def_; + } // Add control inputs after the regular inputs. for (const auto& control_input : control_inputs_) { diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index 63d856d16c6..92d6399d1e2 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -129,9 +129,11 @@ class NodeDefBuilder { // Finish building the NodeDef, returning any errors or setting // *node_def if none. + // If `consume` is true, the builder state will be moved into `node_def`, + // and the builder will be left in an undefined state. // WARNING: Not all problems are detected! The resulting NodeDef may // not be valid! Call ValidateNodeDef() from node_def_utils to be sure. - Status Finalize(NodeDef* node_def) const; + Status Finalize(NodeDef* node_def, bool consume = false); // Accessors for the values set in the constructor. const string& node_name() const { return node_def_.name(); } diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc index 7c4426e276a..d93f8e9e2d8 100644 --- a/tensorflow/core/framework/node_def_builder_test.cc +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -48,7 +48,7 @@ class NodeDefBuilderTest : public ::testing::Test { // Calls Finalize() and verifies it returns success and the result matches // expectations. - void ExpectSuccess(const NodeDefBuilder& builder, + void ExpectSuccess(NodeDefBuilder& builder, // NOLINT DataTypeSlice expected_in_types, DataTypeSlice expected_out_types, StringPiece proto) { NodeDef node_def; @@ -76,7 +76,7 @@ class NodeDefBuilderTest : public ::testing::Test { // Calls Finalize() and verifies it returns an error. // Each message must appear as a substring of the error. - void ExpectFailures(const NodeDefBuilder& builder, + void ExpectFailures(NodeDefBuilder& builder, // NOLINT const std::vector<string>& messages) { NodeDef node_def; Status status = builder.Finalize(&node_def); @@ -90,13 +90,15 @@ class NodeDefBuilderTest : public ::testing::Test { // Calls Finalize() and verifies it returns an error. // Message must appear as a substring of the error. - void ExpectFailure(const NodeDefBuilder& builder, const string& message) { + void ExpectFailure(NodeDefBuilder& builder, // NOLINT + const string& message) { ExpectFailures(builder, {message}); } // Like ExpectFailure(), except that the error can come from // ValidateNodeDef(). - void ExpectInvalid(const NodeDefBuilder& builder, const string& message) { + void ExpectInvalid(NodeDefBuilder& builder, // NOLINT + const string& message) { NodeDef node_def; Status status = builder.Finalize(&node_def); if (status.ok()) { diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 4c4f0e2f37a..0817eb3a4e9 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -43,7 +43,7 @@ NodeDef ToNodeDef(const string& text) { return node_def; } -NodeDef ToNodeDef(const NodeDefBuilder& builder) { +NodeDef ToNodeDef(NodeDefBuilder&& builder) { NodeDef node_def; TF_EXPECT_OK(builder.Finalize(&node_def)); return node_def; @@ -244,14 +244,14 @@ TEST(NodeDefUtilTest, AnyIn) { TEST(NodeDefUtilTest, Device) { const OpDef op_def1 = ToOpDef(OpDefBuilder("None")); const NodeDef node_def1 = - ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17")); + ToNodeDef(std::move(NodeDefBuilder("d", &op_def1).Device("/cpu:17"))); ExpectSuccess(node_def1, op_def1); EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()", SummarizeNodeDef(node_def1)); const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int")); - const NodeDef node_def2 = - ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5")); + const NodeDef node_def2 = ToNodeDef( + std::move(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5"))); ExpectSuccess(node_def2, op_def2); EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()", SummarizeNodeDef(node_def2)); @@ -376,8 +376,8 @@ TEST(InputTypesForNode, Simple) { .Input("b: int32") .Output("c: string") .Output("d: bool")); - const NodeDef node_def = ToNodeDef( - NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + const NodeDef node_def = ToNodeDef(std::move( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()))); DataTypeVector types; EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok()); EXPECT_EQ(types[0], DT_FLOAT); @@ -397,8 +397,8 @@ TEST(OutputTypesForNode, Simple) { .Input("b: int32") .Output("c: string") .Output("d: bool")); - const NodeDef node_def = ToNodeDef( - NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + const NodeDef node_def = ToNodeDef(std::move( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()))); DataTypeVector types; EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok()); EXPECT_EQ(types[0], DT_STRING); @@ -418,8 +418,10 @@ TEST(OutputTypesForNode_AttrSliceOverload, Simple) { .Input("b: int32") .Output("c: string") .Output("d: bool")); - const AttrSlice attr_slice = AttrSlice(ToNodeDef( - NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()))); + const AttrSlice attr_slice = + AttrSlice(ToNodeDef(std::move(NodeDefBuilder("simple", &op_def) + .Input(FakeInput()) + .Input(FakeInput())))); DataTypeVector types; EXPECT_TRUE(OutputTypesForNode(attr_slice, op_def, &types).ok()); EXPECT_EQ(types[0], DT_STRING); @@ -433,8 +435,8 @@ TEST(NameRangesForNodeTest, Simple) { .Output("c: string") .Output("d: bool")); NameRangeMap inputs, outputs; - const NodeDef node_def = ToNodeDef( - NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + const NodeDef node_def = ToNodeDef(std::move( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()))); TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs); @@ -453,18 +455,20 @@ TEST(NameRangesForNodeTest, Polymorphic) { .Output("c: T") .Attr("T: type")); NameRangeMap inputs, outputs; - const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def) - .Input(FakeInput(DT_INT32)) - .Input(FakeInput(DT_INT32))); + const NodeDef node_def1 = + ToNodeDef(std::move(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)))); TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); EXPECT_EQ("{{node poly}} = Polymorphic[T=DT_INT32](a, b)", SummarizeNodeDef(node_def1)); - const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def) - .Input(FakeInput(DT_BOOL)) - .Input(FakeInput(DT_BOOL))); + const NodeDef node_def2 = + ToNodeDef(std::move(NodeDefBuilder("poly", &op_def) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(DT_BOOL)))); TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs); @@ -483,10 +487,11 @@ TEST(NameRangesForNodeTest, NRepeats) { .Attr("M: int") .Attr("T: type")); NameRangeMap inputs, outputs; - const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def) - .Input(FakeInput(4, DT_INT32)) - .Input(FakeInput(4, DT_FLOAT)) - .Attr("M", 3)); + const NodeDef node_def1 = + ToNodeDef(std::move(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(4, DT_INT32)) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("M", 3))); TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}), @@ -496,10 +501,11 @@ TEST(NameRangesForNodeTest, NRepeats) { "b:2, b:3)", SummarizeNodeDef(node_def1)); - const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def) - .Input(FakeInput(2, DT_INT32)) - .Input(FakeInput(2, DT_DOUBLE)) - .Attr("M", 7)); + const NodeDef node_def2 = + ToNodeDef(std::move(NodeDefBuilder("nr", &op_def) + .Input(FakeInput(2, DT_INT32)) + .Input(FakeInput(2, DT_DOUBLE)) + .Attr("M", 7))); TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), @@ -524,10 +530,10 @@ TEST(NameRangesForNodeTest, TypeList) { .Attr("T3: list(type)")); NameRangeMap inputs, outputs; const NodeDef node_def1 = - ToNodeDef(NodeDefBuilder("tl", &op_def) - .Input(FakeInput({DT_BOOL, DT_FLOAT})) - .Input(FakeInput(4, DT_FLOAT)) - .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})); + ToNodeDef(std::move(NodeDefBuilder("tl", &op_def) + .Input(FakeInput({DT_BOOL, DT_FLOAT})) + .Input(FakeInput(4, DT_FLOAT)) + .Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING}))); TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}), @@ -538,10 +544,11 @@ TEST(NameRangesForNodeTest, TypeList) { " T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)", SummarizeNodeDef(node_def1)); - const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def) - .Input(FakeInput(7, DT_INT32)) - .Input(FakeInput({DT_DOUBLE})) - .Attr("T3", {DT_DOUBLE, DT_STRING})); + const NodeDef node_def2 = + ToNodeDef(std::move(NodeDefBuilder("tl", &op_def) + .Input(FakeInput(7, DT_INT32)) + .Input(FakeInput({DT_DOUBLE})) + .Attr("T3", {DT_DOUBLE, DT_STRING}))); TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs)); EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs); EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}), diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index a13769b3315..1c906a3599c 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -227,7 +227,7 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, } NodeDef* cast = gdef->add_node(); - *status = cast_builder.Finalize(cast); + *status = cast_builder.Finalize(cast, /*consume=*/true); if (!status->ok()) return nullptr; // Connect the Send op to the cast. @@ -244,7 +244,7 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, send_builder.Attr("_start_time", start_time); } NodeDef* send = gdef->add_node(); - *status = send_builder.Finalize(send); + *status = send_builder.Finalize(send, /*consume=*/true); return send; } @@ -301,7 +301,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, recv_builder.Device(dst->assigned_device_name()) .Attr("tensor_type", cast_dtype); NodeDef* recv = gdef->add_node(); - *status = recv_builder.Finalize(recv); + *status = recv_builder.Finalize(recv, /*consume=*/true); if (!status->ok()) return nullptr; *real_recv = recv; @@ -314,7 +314,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, cast_builder.Device(dst->assigned_device_name()) .Input(recv->name(), 0, cast_dtype); NodeDef* cast = gdef->add_node(); - *status = cast_builder.Finalize(cast); + *status = cast_builder.Finalize(cast, /*consume=*/true); if (!status->ok()) return nullptr; return cast; } else if (edge->IsControlEdge()) { @@ -324,7 +324,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, id_builder.Device(dst->assigned_device_name()) .Input(recv->name(), 0, cast_dtype); NodeDef* id = gdef->add_node(); - *status = id_builder.Finalize(id); + *status = id_builder.Finalize(id, /*consume=*/true); if (!status->ok()) return nullptr; return id; } else { @@ -341,7 +341,7 @@ NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef, .Device(src->assigned_device_name()) .Attr("dtype", DT_FLOAT) .Attr("value", tensor) - .Finalize(result); + .Finalize(result, /*consume=*/true); return result; } @@ -354,7 +354,7 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, "ControlTrigger") .Device(assigned_device_name) .Attr("_start_time", starttime) - .Finalize(result); + .Finalize(result, /*consume=*/true); return result; } @@ -424,7 +424,7 @@ Node* AddControlEnter(Graph* g, const string& node_name, node_builder.Attr("frame_name", frame_name); node_builder.Attr("parallel_iterations", parallel_iterations); Node* res_node; - *status = node_builder.Finalize(g, &res_node); + *status = node_builder.Finalize(g, &res_node, /*consume=*/true); if (!status->ok()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; @@ -437,7 +437,7 @@ Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, NodeBuilder node_builder(node_name, "Merge", g->op_registry()); node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}}); Node* res_node; - *status = node_builder.Finalize(g, &res_node); + *status = node_builder.Finalize(g, &res_node, /*consume=*/true); if (!status->ok()) return nullptr; res_node->set_assigned_device_name(device_name); return res_node; diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 6ce4531c5bc..07bf49f7f63 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -112,7 +112,7 @@ NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) { return *this; } -Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { +Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) { // In case of error, set *created_node to nullptr. if (created_node != nullptr) *created_node = nullptr; if (!errors_.empty()) { @@ -120,7 +120,7 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { } NodeDef node_def; - TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def)); + TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume)); TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def())); TF_RETURN_IF_ERROR( CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer())); diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 51e044cd8b2..ce4fb4f3c48 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -121,7 +121,9 @@ class NodeBuilder { // Validates the described node and adds it to *graph, adding edges // for all (non-back) inputs. If created_node is not nullptr, // *created_node will be set to the new node (or nullptr on error). - Status Finalize(Graph* graph, Node** created_node) const; + // If `consume` is true, the builder state will be moved into `node_def`, + // and the builder will be left in an undefined state. + Status Finalize(Graph* graph, Node** created_node, bool consume = false); // Accessors for the values set in the constructor. const string& node_name() const { return def_builder_.node_name(); } diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 7d839723f89..e70427f9ef8 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -229,7 +229,7 @@ Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, "_Arg") .Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index))) .Attr("index", arg_index_) - .Finalize(g, out_node)); + .Finalize(g, out_node, /*consume=*/true)); (*out_node)->set_assigned_device_name(device_info().name()); return Status::OK(); } @@ -248,7 +248,7 @@ Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, .Attr("send_device_incarnation", static_cast<int64>(device_info().incarnation())) .Attr("client_terminated", true) - .Finalize(g, out_node)); + .Finalize(g, out_node, /*consume=*/true)); (*out_node)->set_assigned_device_name(device_info().name()); return Status::OK(); @@ -268,7 +268,7 @@ Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, .Attr("T", BaseType(fetch_tensor.node->output_type(fetch_tensor.index))) .Attr("index", retval_index_) - .Finalize(g, out_node)); + .Finalize(g, out_node, /*consume=*/true)); (*out_node)->set_assigned_device_name(device_info().name()); return Status::OK(); } @@ -286,7 +286,7 @@ Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, .Attr("send_device_incarnation", static_cast<int64>(device_info().incarnation())) .Attr("client_terminated", true) - .Finalize(g, out_node)); + .Finalize(g, out_node, /*consume=*/true)); (*out_node)->set_assigned_device_name(device_info().name()); return Status::OK(); }