[Grappler] In the constant folding optimizer: Directly convert Fill, ZerosLike, and OnesLike with known output shape to Const nodes in compressed format without materializing the (potentially large) tensor value by evaluating the node.
PiperOrigin-RevId: 232701020
This commit is contained in:
parent
3bfbcb2abd
commit
1362eaa98e
@ -103,6 +103,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:symbolic_shapes",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
@ -716,6 +717,61 @@ Status ConstantFolding::MaterializeReductionIndices(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Rewrites a Fill/ZerosLike/OnesLike node whose output shape is fully known
// into a Const node holding the value in compressed form (a scalar value plus
// the full output shape in the TensorProto), without ever materializing the
// potentially large output tensor. Returns OK without touching the graph
// whenever any precondition is not met.
Status ConstantFolding::MaterializeConstantValuedNode(
    NodeDef* node, const GraphProperties& properties) {
  // Nodes that generate constant-valued outputs can be represented compactly in
  // compressed format, regardless of their shape.
  const std::vector<OpInfo::TensorProperties>& output_props =
      properties.GetOutputProperties(node->name());
  // Only single-output nodes with a fully defined inferred output shape are
  // handled; otherwise we cannot build the Const's shape.
  if (output_props.size() != 1) return Status::OK();
  const auto& output_shape = output_props[0].shape();
  if (!PartialTensorShape(output_shape).IsFullyDefined()) {
    return Status::OK();
  }
  if (IsFill(*node)) {
    const auto output_dtype = output_props[0].dtype();
    NodeDef* input_node = nullptr;
    // Both inputs must resolve to real Const nodes. After the loop,
    // input_node refers to input 1 (the value operand of Fill).
    for (int i = 0; i < 2; ++i) {
      input_node = node_map_->GetNode(NodeName(node->input(i)));
      if (input_node == nullptr || !IsReallyConstant(*input_node)) {
        return Status::OK();
      }
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
    const TensorProto& input_tensor = input_node->attr().at("value").tensor();
    // TODO(rmlarsen): Handle the case where the value is stored in
    // tensor_content.
    if (!input_tensor.tensor_content().empty()) {
      return Status::OK();
    }
    TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
    // Copy the input tensor to the fill node, set the output shape, and
    // change the node type to Const.
    *tensor = input_tensor;
    *(tensor->mutable_tensor_shape()) = output_shape;
    (*node->mutable_attr())["dtype"].set_type(output_dtype);
    // Drop attrs carried by Fill that are not part of Const's signature.
    node->mutable_attr()->erase("T");
    node->mutable_attr()->erase("index_type");
    node->set_op("Const");
    for (int i = 0; i < 2; i++) {
      // Demote the data inputs to control inputs so execution ordering is
      // preserved. The node map is updated before the input string is
      // rewritten, since UpdateInput needs the old input name.
      const string ctrl_dep = AsControlDependency(node->input(i));
      node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
      node->set_input(i, ctrl_dep);
    }
    graph_modified_ = true;
  } else {
    // ZerosLike -> 0.0, OnesLike -> 1.0; any other op yields -1.0 and is
    // left untouched by the guard below.
    double value =
        (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
    bool success = false;
    if (value >= 0) {
      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
          value, properties, output_shape, node, graph_, &success));
    }
  }
  return Status::OK();
}
|
||||
|
||||
Status ConstantFolding::MaterializeConstants(
|
||||
const GraphProperties& properties) {
|
||||
const int node_count = graph_->node_size();
|
||||
@ -726,6 +782,8 @@ Status ConstantFolding::MaterializeConstants(
|
||||
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
|
||||
} else if (IsReduction(node)) {
|
||||
TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
|
||||
} else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
|
||||
TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -1569,6 +1627,7 @@ Status ConstantFolding::ReplaceOperationWithConstant(
|
||||
node->set_input(i, ctrl_dep);
|
||||
}
|
||||
*success = true;
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -67,8 +67,10 @@ class ConstantFolding : public GraphOptimizer {
|
||||
const GraphProperties& properties);
|
||||
Status MaterializeReductionIndices(NodeDef* node,
|
||||
const GraphProperties& properties);
|
||||
|
||||
Status MaterializeConstantValuedNode(NodeDef* node,
|
||||
const GraphProperties& properties);
|
||||
Status MaterializeConstants(const GraphProperties& properties);
|
||||
|
||||
bool IsFoldable(const NodeDef& node) const;
|
||||
|
||||
Status EvaluateNode(const NodeDef& node,
|
||||
|
@ -378,7 +378,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
||||
const string ones_name = strings::StrCat("ones", suffix);
|
||||
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
|
||||
const string ctrl_ones_name = strings::StrCat("^ones", suffix);
|
||||
EXPECT_EQ(27, output.node_size());
|
||||
EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
|
||||
for (int i = 0; i < output.node_size(); ++i) {
|
||||
const NodeDef& node = output.node(i);
|
||||
const string& name = node.name();
|
||||
@ -3466,6 +3466,55 @@ TEST_F(ConstantFoldingCastConstTest, CastConstFolding) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
  // One placeholder feeding OnesLike/ZerosLike, plus a Fill with constant
  // inputs: all three should be rewritten to Const by the optimizer.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output x =
      ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                       ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
  Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
  Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
  Output fill = ops::Fill(s.WithOpName("fill"), {4, 3, 2, 1}, 42);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"ones_like", "zeros_like", "fill"};
  auto x_value = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
  auto expected_tensors =
      EvaluateNodes(item.graph, item.fetch, {{"x", x_value}});

  ConstantFolding optimizer(/*opt_level=*/RewriterConfig::AGGRESSIVE,
                            /*cpu_device=*/nullptr);
  GraphDef output;
  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
  TF_EXPECT_OK(status);

  // Everything except the placeholder must now be a Const, with the former
  // data inputs demoted to control dependencies.
  EXPECT_EQ(output.node_size(), 6);
  for (int i = 0; i < output.node_size(); ++i) {
    const NodeDef& node = output.node(i);
    const string& name = node.name();
    if (name != "x") {
      EXPECT_EQ(node.op(), "Const");
    }
    if (name == "ones_like" || name == "zeros_like") {
      ASSERT_EQ(node.input_size(), 1);
      EXPECT_EQ(node.input(0), "^x");
    }
    if (name == "fill") {
      ASSERT_EQ(node.input_size(), 2);
      EXPECT_EQ(node.input(0)[0], '^');
      EXPECT_EQ(node.input(1)[0], '^');
    }
  }

  // The optimized graph must still compute the original values.
  auto actual_tensors = EvaluateNodes(output, item.fetch, {{"x", x_value}});
  ASSERT_EQ(item.fetch.size(), actual_tensors.size());
  ASSERT_EQ(expected_tensors.size(), actual_tensors.size());
  for (int i = 0; i < actual_tensors.size(); i++) {
    if (item.fetch[i] == "fill") {
      test::ExpectTensorEqual<int>(expected_tensors[i], actual_tensors[i]);
    } else {
      test::ExpectTensorEqual<float>(expected_tensors[i], actual_tensors[i]);
    }
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user