Allow users of Node[Def]Builder to avoid copying the created NodeDef on finalization.

By passing true as the optional `consume` argument, we can move the constructed NodeDef out of the NodeDefBuilder, which avoids a potentially large copy.

PiperOrigin-RevId: 259089263
This commit is contained in:
Derek Murray 2019-07-19 22:05:05 -07:00 committed by TensorFlower Gardener
parent 50885ca141
commit ed87a6ddc3
8 changed files with 74 additions and 57 deletions

View File

@ -211,7 +211,7 @@ NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
return *this;
}
Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) {
const std::vector<string>* errors_ptr = &errors_;
std::vector<string> errors_storage;
if (op_def_ != nullptr && inputs_specified_ < op_def_->input_arg_size()) {
@ -243,7 +243,11 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def) const {
} else {
NodeDef node_def_backup;
if (node_def == nullptr) node_def = &node_def_backup;
*node_def = node_def_;
if (consume) {
*node_def = std::move(node_def_);
} else {
*node_def = node_def_;
}
// Add control inputs after the regular inputs.
for (const auto& control_input : control_inputs_) {

View File

@ -129,9 +129,11 @@ class NodeDefBuilder {
// Finish building the NodeDef, returning any errors or setting
// *node_def if none.
// If `consume` is true, the builder state will be moved into `node_def`,
// and the builder will be left in an undefined state.
// WARNING: Not all problems are detected! The resulting NodeDef may
// not be valid! Call ValidateNodeDef() from node_def_utils to be sure.
Status Finalize(NodeDef* node_def) const;
Status Finalize(NodeDef* node_def, bool consume = false);
// Accessors for the values set in the constructor.
const string& node_name() const { return node_def_.name(); }

View File

@ -48,7 +48,7 @@ class NodeDefBuilderTest : public ::testing::Test {
// Calls Finalize() and verifies it returns success and the result matches
// expectations.
void ExpectSuccess(const NodeDefBuilder& builder,
void ExpectSuccess(NodeDefBuilder& builder, // NOLINT
DataTypeSlice expected_in_types,
DataTypeSlice expected_out_types, StringPiece proto) {
NodeDef node_def;
@ -76,7 +76,7 @@ class NodeDefBuilderTest : public ::testing::Test {
// Calls Finalize() and verifies it returns an error.
// Each message must appear as a substring of the error.
void ExpectFailures(const NodeDefBuilder& builder,
void ExpectFailures(NodeDefBuilder& builder, // NOLINT
const std::vector<string>& messages) {
NodeDef node_def;
Status status = builder.Finalize(&node_def);
@ -90,13 +90,15 @@ class NodeDefBuilderTest : public ::testing::Test {
// Calls Finalize() and verifies it returns an error.
// Message must appear as a substring of the error.
void ExpectFailure(const NodeDefBuilder& builder, const string& message) {
void ExpectFailure(NodeDefBuilder& builder, // NOLINT
const string& message) {
ExpectFailures(builder, {message});
}
// Like ExpectFailure(), except that the error can come from
// ValidateNodeDef().
void ExpectInvalid(const NodeDefBuilder& builder, const string& message) {
void ExpectInvalid(NodeDefBuilder& builder, // NOLINT
const string& message) {
NodeDef node_def;
Status status = builder.Finalize(&node_def);
if (status.ok()) {

View File

@ -43,7 +43,7 @@ NodeDef ToNodeDef(const string& text) {
return node_def;
}
NodeDef ToNodeDef(const NodeDefBuilder& builder) {
NodeDef ToNodeDef(NodeDefBuilder&& builder) {
NodeDef node_def;
TF_EXPECT_OK(builder.Finalize(&node_def));
return node_def;
@ -244,14 +244,14 @@ TEST(NodeDefUtilTest, AnyIn) {
TEST(NodeDefUtilTest, Device) {
const OpDef op_def1 = ToOpDef(OpDefBuilder("None"));
const NodeDef node_def1 =
ToNodeDef(NodeDefBuilder("d", &op_def1).Device("/cpu:17"));
ToNodeDef(std::move(NodeDefBuilder("d", &op_def1).Device("/cpu:17")));
ExpectSuccess(node_def1, op_def1);
EXPECT_EQ("{{node d}} = None[_device=\"/cpu:17\"]()",
SummarizeNodeDef(node_def1));
const OpDef op_def2 = ToOpDef(OpDefBuilder("WithAttr").Attr("v: int"));
const NodeDef node_def2 =
ToNodeDef(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5"));
const NodeDef node_def2 = ToNodeDef(
std::move(NodeDefBuilder("d", &op_def2).Attr("v", 7).Device("/cpu:5")));
ExpectSuccess(node_def2, op_def2);
EXPECT_EQ("{{node d}} = WithAttr[v=7, _device=\"/cpu:5\"]()",
SummarizeNodeDef(node_def2));
@ -376,8 +376,8 @@ TEST(InputTypesForNode, Simple) {
.Input("b: int32")
.Output("c: string")
.Output("d: bool"));
const NodeDef node_def = ToNodeDef(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
const NodeDef node_def = ToNodeDef(std::move(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
DataTypeVector types;
EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
EXPECT_EQ(types[0], DT_FLOAT);
@ -397,8 +397,8 @@ TEST(OutputTypesForNode, Simple) {
.Input("b: int32")
.Output("c: string")
.Output("d: bool"));
const NodeDef node_def = ToNodeDef(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
const NodeDef node_def = ToNodeDef(std::move(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
DataTypeVector types;
EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
EXPECT_EQ(types[0], DT_STRING);
@ -418,8 +418,10 @@ TEST(OutputTypesForNode_AttrSliceOverload, Simple) {
.Input("b: int32")
.Output("c: string")
.Output("d: bool"));
const AttrSlice attr_slice = AttrSlice(ToNodeDef(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
const AttrSlice attr_slice =
AttrSlice(ToNodeDef(std::move(NodeDefBuilder("simple", &op_def)
.Input(FakeInput())
.Input(FakeInput()))));
DataTypeVector types;
EXPECT_TRUE(OutputTypesForNode(attr_slice, op_def, &types).ok());
EXPECT_EQ(types[0], DT_STRING);
@ -433,8 +435,8 @@ TEST(NameRangesForNodeTest, Simple) {
.Output("c: string")
.Output("d: bool"));
NameRangeMap inputs, outputs;
const NodeDef node_def = ToNodeDef(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
const NodeDef node_def = ToNodeDef(std::move(
NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())));
TF_EXPECT_OK(NameRangesForNode(node_def, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 2}}}), outputs);
@ -453,18 +455,20 @@ TEST(NameRangesForNodeTest, Polymorphic) {
.Output("c: T")
.Attr("T: type"));
NameRangeMap inputs, outputs;
const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("poly", &op_def)
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_INT32)));
const NodeDef node_def1 =
ToNodeDef(std::move(NodeDefBuilder("poly", &op_def)
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_INT32))));
TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
EXPECT_EQ("{{node poly}} = Polymorphic[T=DT_INT32](a, b)",
SummarizeNodeDef(node_def1));
const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("poly", &op_def)
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(DT_BOOL)));
const NodeDef node_def2 =
ToNodeDef(std::move(NodeDefBuilder("poly", &op_def)
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(DT_BOOL))));
TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 1}}, {"b", {1, 2}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}}), outputs);
@ -483,10 +487,11 @@ TEST(NameRangesForNodeTest, NRepeats) {
.Attr("M: int")
.Attr("T: type"));
NameRangeMap inputs, outputs;
const NodeDef node_def1 = ToNodeDef(NodeDefBuilder("nr", &op_def)
.Input(FakeInput(4, DT_INT32))
.Input(FakeInput(4, DT_FLOAT))
.Attr("M", 3));
const NodeDef node_def1 =
ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
.Input(FakeInput(4, DT_INT32))
.Input(FakeInput(4, DT_FLOAT))
.Attr("M", 3)));
TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 4}}, {"b", {4, 8}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 5}}, {"e", {5, 8}}}),
@ -496,10 +501,11 @@ TEST(NameRangesForNodeTest, NRepeats) {
"b:2, b:3)",
SummarizeNodeDef(node_def1));
const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("nr", &op_def)
.Input(FakeInput(2, DT_INT32))
.Input(FakeInput(2, DT_DOUBLE))
.Attr("M", 7));
const NodeDef node_def2 =
ToNodeDef(std::move(NodeDefBuilder("nr", &op_def)
.Input(FakeInput(2, DT_INT32))
.Input(FakeInput(2, DT_DOUBLE))
.Attr("M", 7)));
TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 4}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),
@ -524,10 +530,10 @@ TEST(NameRangesForNodeTest, TypeList) {
.Attr("T3: list(type)"));
NameRangeMap inputs, outputs;
const NodeDef node_def1 =
ToNodeDef(NodeDefBuilder("tl", &op_def)
.Input(FakeInput({DT_BOOL, DT_FLOAT}))
.Input(FakeInput(4, DT_FLOAT))
.Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING}));
ToNodeDef(std::move(NodeDefBuilder("tl", &op_def)
.Input(FakeInput({DT_BOOL, DT_FLOAT}))
.Input(FakeInput(4, DT_FLOAT))
.Attr("T3", {DT_INT32, DT_DOUBLE, DT_STRING})));
TF_EXPECT_OK(NameRangesForNode(node_def1, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 2}}, {"b", {2, 6}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 4}}, {"d", {4, 7}}, {"e", {7, 9}}}),
@ -538,10 +544,11 @@ TEST(NameRangesForNodeTest, TypeList) {
" T3=[DT_INT32, DT_DOUBLE, DT_STRING]](a, a:1, b, b:1, b:2, b:3)",
SummarizeNodeDef(node_def1));
const NodeDef node_def2 = ToNodeDef(NodeDefBuilder("tl", &op_def)
.Input(FakeInput(7, DT_INT32))
.Input(FakeInput({DT_DOUBLE}))
.Attr("T3", {DT_DOUBLE, DT_STRING}));
const NodeDef node_def2 =
ToNodeDef(std::move(NodeDefBuilder("tl", &op_def)
.Input(FakeInput(7, DT_INT32))
.Input(FakeInput({DT_DOUBLE}))
.Attr("T3", {DT_DOUBLE, DT_STRING})));
TF_EXPECT_OK(NameRangesForNode(node_def2, op_def, &inputs, &outputs));
EXPECT_EQ(NameRangeMap({{"a", {0, 7}}, {"b", {7, 8}}}), inputs);
EXPECT_EQ(NameRangeMap({{"c", {0, 1}}, {"d", {1, 3}}, {"e", {3, 10}}}),

View File

@ -227,7 +227,7 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
}
NodeDef* cast = gdef->add_node();
*status = cast_builder.Finalize(cast);
*status = cast_builder.Finalize(cast, /*consume=*/true);
if (!status->ok()) return nullptr;
// Connect the Send op to the cast.
@ -244,7 +244,7 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
send_builder.Attr("_start_time", start_time);
}
NodeDef* send = gdef->add_node();
*status = send_builder.Finalize(send);
*status = send_builder.Finalize(send, /*consume=*/true);
return send;
}
@ -301,7 +301,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
recv_builder.Device(dst->assigned_device_name())
.Attr("tensor_type", cast_dtype);
NodeDef* recv = gdef->add_node();
*status = recv_builder.Finalize(recv);
*status = recv_builder.Finalize(recv, /*consume=*/true);
if (!status->ok()) return nullptr;
*real_recv = recv;
@ -314,7 +314,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
cast_builder.Device(dst->assigned_device_name())
.Input(recv->name(), 0, cast_dtype);
NodeDef* cast = gdef->add_node();
*status = cast_builder.Finalize(cast);
*status = cast_builder.Finalize(cast, /*consume=*/true);
if (!status->ok()) return nullptr;
return cast;
} else if (edge->IsControlEdge()) {
@ -324,7 +324,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
id_builder.Device(dst->assigned_device_name())
.Input(recv->name(), 0, cast_dtype);
NodeDef* id = gdef->add_node();
*status = id_builder.Finalize(id);
*status = id_builder.Finalize(id, /*consume=*/true);
if (!status->ok()) return nullptr;
return id;
} else {
@ -341,7 +341,7 @@ NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
.Device(src->assigned_device_name())
.Attr("dtype", DT_FLOAT)
.Attr("value", tensor)
.Finalize(result);
.Finalize(result, /*consume=*/true);
return result;
}
@ -354,7 +354,7 @@ NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
"ControlTrigger")
.Device(assigned_device_name)
.Attr("_start_time", starttime)
.Finalize(result);
.Finalize(result, /*consume=*/true);
return result;
}
@ -424,7 +424,7 @@ Node* AddControlEnter(Graph* g, const string& node_name,
node_builder.Attr("frame_name", frame_name);
node_builder.Attr("parallel_iterations", parallel_iterations);
Node* res_node;
*status = node_builder.Finalize(g, &res_node);
*status = node_builder.Finalize(g, &res_node, /*consume=*/true);
if (!status->ok()) return nullptr;
res_node->set_assigned_device_name(device_name);
return res_node;
@ -437,7 +437,7 @@ Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
NodeBuilder node_builder(node_name, "Merge", g->op_registry());
node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
Node* res_node;
*status = node_builder.Finalize(g, &res_node);
*status = node_builder.Finalize(g, &res_node, /*consume=*/true);
if (!status->ok()) return nullptr;
res_node->set_assigned_device_name(device_name);
return res_node;

View File

@ -112,7 +112,7 @@ NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
return *this;
}
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
if (!errors_.empty()) {
@ -120,7 +120,7 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
}
NodeDef node_def;
TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def));
TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
TF_RETURN_IF_ERROR(
CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));

View File

@ -121,7 +121,9 @@ class NodeBuilder {
// Validates the described node and adds it to *graph, adding edges
// for all (non-back) inputs. If created_node is not nullptr,
// *created_node will be set to the new node (or nullptr on error).
Status Finalize(Graph* graph, Node** created_node) const;
// If `consume` is true, the builder state will be moved into `node_def`,
// and the builder will be left in an undefined state.
Status Finalize(Graph* graph, Node** created_node, bool consume = false);
// Accessors for the values set in the constructor.
const string& node_name() const { return def_builder_.node_name(); }

View File

@ -229,7 +229,7 @@ Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
"_Arg")
.Attr("T", BaseType(feed_tensor.node->output_type(feed_tensor.index)))
.Attr("index", arg_index_)
.Finalize(g, out_node));
.Finalize(g, out_node, /*consume=*/true));
(*out_node)->set_assigned_device_name(device_info().name());
return Status::OK();
}
@ -248,7 +248,7 @@ Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
.Attr("send_device_incarnation",
static_cast<int64>(device_info().incarnation()))
.Attr("client_terminated", true)
.Finalize(g, out_node));
.Finalize(g, out_node, /*consume=*/true));
(*out_node)->set_assigned_device_name(device_info().name());
return Status::OK();
@ -268,7 +268,7 @@ Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
.Attr("T",
BaseType(fetch_tensor.node->output_type(fetch_tensor.index)))
.Attr("index", retval_index_)
.Finalize(g, out_node));
.Finalize(g, out_node, /*consume=*/true));
(*out_node)->set_assigned_device_name(device_info().name());
return Status::OK();
}
@ -286,7 +286,7 @@ Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor,
.Attr("send_device_incarnation",
static_cast<int64>(device_info().incarnation()))
.Attr("client_terminated", true)
.Finalize(g, out_node));
.Finalize(g, out_node, /*consume=*/true));
(*out_node)->set_assigned_device_name(device_info().name());
return Status::OK();
}