From 1362eaa98e3e7778446eb767473450ef484996ca Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 6 Feb 2019 10:37:05 -0800
Subject: [PATCH] [Grappler] In the constant folding optimizer: Directly
 convert Fill, ZerosLike, and OnesLike with known output shape to Const nodes
 in compressed format without materializing the (potentially large) tensor
 value by evaluating the node.

PiperOrigin-RevId: 232701020
---
 tensorflow/core/grappler/optimizers/BUILD     |  1 +
 .../grappler/optimizers/constant_folding.cc   | 59 +++++++++++++++++++
 .../grappler/optimizers/constant_folding.h    |  4 +-
 .../optimizers/constant_folding_test.cc       | 51 +++++++++++++++-
 4 files changed, 113 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index cb5d7d67468..cbf0d680fe7 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -103,6 +103,7 @@ cc_library(
         "//tensorflow/core/grappler/costs:graph_properties",
         "//tensorflow/core/grappler/utils:symbolic_shapes",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index da7a3280015..f883f8926f8 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
 
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -716,6 +717,61 @@ Status ConstantFolding::MaterializeReductionIndices(
   return Status::OK();
 }
 
+Status ConstantFolding::MaterializeConstantValuedNode(
+    NodeDef* node, const GraphProperties& properties) {
+  // Nodes that generate constant-valued outputs can be represented compactly in
+  // compressed format, regardless of their shape.
+  const std::vector<OpInfo::TensorProperties>& output_props =
+      properties.GetOutputProperties(node->name());
+  if (output_props.size() != 1) return Status::OK();
+  const auto& output_shape = output_props[0].shape();
+  if (!PartialTensorShape(output_shape).IsFullyDefined()) {
+    return Status::OK();
+  }
+  if (IsFill(*node)) {
+    const auto output_dtype = output_props[0].dtype();
+    NodeDef* input_node = nullptr;
+    for (int i = 0; i < 2; ++i) {
+      input_node = node_map_->GetNode(NodeName(node->input(i)));
+      if (input_node == nullptr || !IsReallyConstant(*input_node)) {
+        return Status::OK();
+      }
+    }
+    TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
+    const TensorProto& input_tensor = input_node->attr().at("value").tensor();
+    // TODO(rmlarsen): Handle the case where the value is stored in
+    // tensor_content.
+    if (!input_tensor.tensor_content().empty()) {
+      return Status::OK();
+    }
+    TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
+    // Copy the input tensor to the fill node, set the output shape, and
+    // change the node type to Const.
+    *tensor = input_tensor;
+    *(tensor->mutable_tensor_shape()) = output_shape;
+    (*node->mutable_attr())["dtype"].set_type(output_dtype);
+    node->mutable_attr()->erase("T");
+    node->mutable_attr()->erase("index_type");
+    node->set_op("Const");
+    for (int i = 0; i < 2; i++) {
+      // Change inputs to control inputs.
+      const string ctrl_dep = AsControlDependency(node->input(i));
+      node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
+      node->set_input(i, ctrl_dep);
+    }
+    graph_modified_ = true;
+  } else {
+    double value =
+        (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
+    bool success = false;
+    if (value >= 0) {
+      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+          value, properties, output_shape, node, graph_, &success));
+    }
+  }
+  return Status::OK();
+}
+
 Status ConstantFolding::MaterializeConstants(
     const GraphProperties& properties) {
   const int node_count = graph_->node_size();
@@ -726,6 +782,8 @@ Status ConstantFolding::MaterializeConstants(
       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
     } else if (IsReduction(node)) {
      TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
+    } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
+      TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
     }
   }
   return Status::OK();
@@ -1569,6 +1627,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
     node->set_input(i, ctrl_dep);
   }
   *success = true;
+  graph_modified_ = true;
   return Status::OK();
 }
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 4c532d7af12..7cf01b4b62c 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -67,8 +67,10 @@ class ConstantFolding : public GraphOptimizer {
                                           const GraphProperties& properties);
   Status MaterializeReductionIndices(NodeDef* node,
                                      const GraphProperties& properties);
-
+  Status MaterializeConstantValuedNode(NodeDef* node,
+                                       const GraphProperties& properties);
   Status MaterializeConstants(const GraphProperties& properties);
+
   bool IsFoldable(const NodeDef& node) const;
 
   Status EvaluateNode(const NodeDef& node,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index d7cabf5a8b8..81d00fa5fbf 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -378,7 +378,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     const string ones_name = strings::StrCat("ones", suffix);
     const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
     const string ctrl_ones_name = strings::StrCat("^ones", suffix);
-    EXPECT_EQ(27, output.node_size());
+    EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -3466,6 +3466,55 @@ TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
   }
 }
 
+TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+  Output x =
+      ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
+                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
+  Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
+  Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
+  Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+  item.fetch = {"ones_like", "zeros_like", "fill"};
+  auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+
+  ConstantFolding optimizer(/*opt_level=*/RewriterConfig::AGGRESSIVE,
+                            /*cpu_device=*/nullptr);
+  GraphDef output;
+  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(output.node_size(), 6);
+  for (const auto& node : output.node()) {
+    if (node.name() != "x") {
+      EXPECT_EQ(node.op(), "Const");
+    }
+    if (node.name() == "ones_like" || node.name() == "zeros_like") {
+      ASSERT_EQ(node.input_size(), 1);
+      EXPECT_EQ(node.input(0), "^x");
+    }
+    if (node.name() == "fill") {
+      ASSERT_EQ(node.input_size(), 2);
+      EXPECT_EQ(node.input(0)[0], '^');
+      EXPECT_EQ(node.input(1)[0], '^');
+    }
+  }
+  auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+  ASSERT_EQ(item.fetch.size(), tensors.size());
+  ASSERT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); i++) {
+    if (item.fetch[i] == "fill") {
+      test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
+    } else {
+      test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
+    }
+  }
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
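
Note (not part of the patch): below is a minimal C++ sketch of the "compressed" Const representation this optimization relies on, assuming the standard TensorProto splat convention in which a single repeated value plus a full shape stands for a constant-valued tensor. The helper name MakeCompressedFloatConst is illustrative only and does not appear in the patch.

    #include <string>

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/framework/tensor.pb.h"
    #include "tensorflow/core/framework/tensor_shape.pb.h"
    #include "tensorflow/core/framework/types.pb.h"

    namespace tensorflow {

    // Illustrative helper (not in the patch): builds a Const NodeDef that is
    // equivalent to Fill(shape, value) for float tensors, without ever
    // materializing the element buffer.
    NodeDef MakeCompressedFloatConst(const std::string& name,
                                     const TensorShapeProto& shape,
                                     float value) {
      NodeDef node;
      node.set_name(name);
      node.set_op("Const");
      (*node.mutable_attr())["dtype"].set_type(DT_FLOAT);
      TensorProto* tensor = (*node.mutable_attr())["value"].mutable_tensor();
      tensor->set_dtype(DT_FLOAT);
      // The full output shape is recorded in the proto...
      *tensor->mutable_tensor_shape() = shape;
      // ...but only a single element is stored: one entry in float_val is
      // broadcast to every element when the tensor is parsed, so the proto
      // stays O(1) in size regardless of the shape.
      tensor->add_float_val(value);
      return node;
    }

    }  // namespace tensorflow

This mirrors what MaterializeConstantValuedNode produces for Fill with a scalar constant value input; the node's original inputs are rewritten as control dependencies so execution ordering is preserved.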