Add an option to avoid copying TensorProtos in shape inference when they are not needed. This saves a significant amount of memory and time in various optimizers that require static shape inference. Previously, calling GraphProperties::InferStatically made at least one copy of every constant tensor stored in the graph. This CL adds a new argument to InferStatically that allows callers to skip copying tensor values into InputProperties and OutputProperties, for optimizers that only need shape and datatype information.

Turning this off for the constant folding and arithmetic optimizers required cleaning up a few rewrite rules that read the copied constant tensor values from GraphProperties instead of extracting the original value directly from the corresponding constant node in the graph.
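For illustration, a caller that only needs shapes and dtypes can now opt out of the value copies. This is a minimal sketch, not code from this change; the function name and the "some_node" lookup are placeholders:

#include "tensorflow/core/grappler/costs/graph_properties.h"

// Sketch: run static shape inference for a GrapplerItem without exporting
// constant tensor values into the inferred input/output properties.
tensorflow::Status InferShapesOnly(
    const tensorflow::grappler::GrapplerItem& item) {
  tensorflow::grappler::GraphProperties properties(item);
  // The new third argument skips copying TensorProtos into
  // InputProperties/OutputProperties, saving memory and time.
  TF_RETURN_IF_ERROR(properties.InferStatically(
      /*assume_valid_feeds=*/false,
      /*aggressive_shape_inference=*/false,
      /*include_tensor_values=*/false));
  // Shapes and dtypes remain available; only the .value field is omitted.
  const auto& out_props = properties.GetOutputProperties("some_node");
  (void)out_props;
  return tensorflow::Status::OK();
}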

Measured results on Transformer graph:

CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB
Benchmark                Time(ns)        CPU(ns) Allocs     Iterations
----------------------------------------------------------------------
Before:
BM_OptimizeTransformer 5161658741     5281083018 23532315            1   139.189MB peak-mem
After:
BM_OptimizeTransformer 4887891650     5005937073 23063225            1   132.652MB peak-mem

Effects on this graph are not dramatic, since it does not contain a lot of large constant tensors. The effect would be much more pronounced on frozen inference graphs.

PiperOrigin-RevId: 250740251
commit 705b193812 (parent d7d2307248)
Author: A. Unique TensorFlower
Date: 2019-05-30 12:36:28 -07:00
Committed by: TensorFlower Gardener

15 changed files with 643 additions and 663 deletions


@@ -712,7 +712,10 @@ class SymbolicShapeRefiner {
     // Perform inference on function body.
     GraphProperties gp(grappler_function_item);
-    TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
+    TF_RETURN_IF_ERROR(gp.InferStatically(
+        /*assume_valid_feeds=*/true,
+        /*aggressive_shape_inference=*/aggressive_shape_inference_,
+        /*include_tensor_values=*/true));
 
     // Add return nodes for output shapes.
     int output = 0;
@@ -2066,7 +2069,8 @@ Status GraphProperties::UpdateEnqueue(
 }
 
 Status GraphProperties::InferStatically(bool assume_valid_feeds,
-                                        bool aggressive_shape_inference) {
+                                        bool aggressive_shape_inference,
+                                        bool include_tensor_values) {
   FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                              item_.graph.library());
   std::unordered_map<string, std::unordered_set<int>> fed_ports;
@@ -2225,20 +2229,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
                                           &input_properties[i]);
         input.port_id = i;
         GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
-        // Export tensor value to input_properties.value.
-        if (IsConstant(*fanin.node)) {
-          const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
-          *input_properties[i].mutable_value() = raw_val;
-        } else if (ctx->input_tensor_protos.size() > i &&
-                   ctx->input_tensor_protos[i] != nullptr) {
-          *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
-        } else if (ic->input_tensors_as_shapes().size() > i &&
-                   IsShapeFullyDefinedIntegerVectorOrScalar(
-                       ic, ic->input(i), ic->input_tensors_as_shapes()[i],
-                       ctx->input_types[i])) {
-          *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
-              ic, ic->input(i), ic->input_tensors_as_shapes()[i],
-              ctx->input_types[i]);
+        if (include_tensor_values) {
+          // Export tensor value to input_properties.value.
+          if (IsConstant(*fanin.node)) {
+            const TensorProto& raw_val =
+                fanin.node->attr().at("value").tensor();
+            *input_properties[i].mutable_value() = raw_val;
+          } else if (ctx->input_tensor_protos.size() > i &&
+                     ctx->input_tensor_protos[i] != nullptr) {
+            *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
+          } else if (ic->input_tensors_as_shapes().size() > i &&
+                     IsShapeFullyDefinedIntegerVectorOrScalar(
+                         ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+                         ctx->input_types[i])) {
+            *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
+                ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+                ctx->input_types[i]);
+          }
         }
       }
     }
@@ -2254,20 +2261,24 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
       for (int i = 0; i < ic->num_outputs(); ++i) {
         shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
                                           &output_properties[i]);
-        // Export tensor value to output_properties.value.
-        if (IsConstant(node)) {
-          const TensorProto& raw_val = node.attr().at("value").tensor();
-          *output_properties[i].mutable_value() = raw_val;
-        } else if (ctx->output_tensor_protos.size() > i &&
-                   ctx->output_tensor_protos[i] != nullptr) {
-          *output_properties[i].mutable_value() = *ctx->output_tensor_protos[i];
-        } else if (ctx->output_tensors_as_shapes.size() > i &&
-                   IsShapeFullyDefinedIntegerVectorOrScalar(
-                       ic, ic->output(i), ctx->output_tensors_as_shapes[i],
-                       ctx->output_types[i])) {
-          *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
-              ic, ic->output(i), ctx->output_tensors_as_shapes[i],
-              ctx->output_types[i]);
+        if (include_tensor_values) {
+          // Export tensor value to output_properties.value.
+          if (IsConstant(node)) {
+            // TODO(rmlarsen): Eliminate this copy.
+            const TensorProto& raw_val = node.attr().at("value").tensor();
+            *output_properties[i].mutable_value() = raw_val;
+          } else if (ctx->output_tensor_protos.size() > i &&
+                     ctx->output_tensor_protos[i] != nullptr) {
+            *output_properties[i].mutable_value() =
+                *ctx->output_tensor_protos[i];
+          } else if (ctx->output_tensors_as_shapes.size() > i &&
+                     IsShapeFullyDefinedIntegerVectorOrScalar(
+                         ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+                         ctx->output_types[i])) {
+            *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
+                ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+                ctx->output_types[i]);
+          }
         }
       }
     }


@@ -89,11 +89,15 @@ class GraphProperties {
   // output values when possible and does other aggressive strategies.
   // Similar to assuming_valid_feeds, this may cause incorrectness in graph
   // analyses, but is useful for simulation or scheduling.
+  // If include_values is true, the values of constant tensors will be
+  // included in the input and output properties.
   Status InferStatically(bool assume_valid_feeds,
-                         bool aggressive_shape_inference);
+                         bool aggressive_shape_inference,
+                         bool include_tensor_values);
   Status InferStatically(bool assume_valid_feeds) {
     return InferStatically(assume_valid_feeds,
-                           /*aggressive_shape_inference=*/false);
+                           /*aggressive_shape_inference=*/false,
+                           /*include_tensor_values=*/true);
   }
   // Infer the shape by running the graph on the specified cluster and recording
   // the shapes of the processed tensors.
@@ -117,6 +121,7 @@ class GraphProperties {
       const string& node_name) const;
   const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
       const string& node_name) const;
+
   // Invalidate input/output properties for nodes modified during graph
   // optimization pass, to prevent potential optimizations, based on incorrect
   // shape information.


@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/costs/graph_properties.h"
+
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/graph_def_util.h"
@@ -1009,7 +1010,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
     GraphProperties properties(item);
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/false,
-        /*aggressive_shape_inference=*/true));
+        /*aggressive_shape_inference=*/true,
+        /*include_tensor_values=*/true));
     const auto out_props = properties.GetOutputProperties("fill");
     const OpInfo::TensorProperties out_prop0 = out_props[0];
     EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
@@ -1028,7 +1030,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
     GraphProperties properties(item);
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/false,
-        /*aggressive_shape_inference=*/true));
+        /*aggressive_shape_inference=*/true,
+        /*include_tensor_values=*/true));
     const auto out_props = properties.GetOutputProperties("fill");
     const OpInfo::TensorProperties out_prop0 = out_props[0];
     EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
@@ -1248,10 +1251,12 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
     // evaluate output value.
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/true,
-        /*aggressive_shape_inference=*/false));
+        /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/true));
     const auto out_props = properties.GetOutputProperties("MyFunc");
     const OpInfo::TensorProperties out_prop0 = out_props[0];
     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
+    LOG(INFO) << out_prop0.DebugString();
     EXPECT_FALSE(out_prop0.has_value());
   }
@@ -1260,7 +1265,8 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
     // With aggressive_shape_inference, output value is evaluated.
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/true,
-        /*aggressive_shape_inference=*/true));
+        /*aggressive_shape_inference=*/true,
+        /*include_tensor_values=*/true));
     const auto out_props = properties.GetOutputProperties("MyFunc");
     const OpInfo::TensorProperties out_prop0 = out_props[0];
     EXPECT_EQ("int32: [2]", PropToString(out_prop0));
@@ -1802,7 +1808,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
     GraphProperties properties(item);
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/false,
-        /*aggressive_shape_inference=*/false));
+        /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/true));
     EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
   }
@@ -1812,7 +1819,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
     GraphProperties properties(item);
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/false,
-        /*aggressive_shape_inference=*/true));
+        /*aggressive_shape_inference=*/true,
+        /*include_tensor_values=*/true));
     EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value());
     const auto slice_value =
         properties.GetOutputProperties("slice").at(0).value();
@@ -1838,7 +1846,8 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
   GraphProperties properties(item);
   TF_CHECK_OK(properties.InferStatically(
       /*assume_valid_feeds=*/false,
-      /*aggressive_shape_inference=*/true));
+      /*aggressive_shape_inference=*/true,
+      /*include_tensor_values=*/true));
 
   // Check output shapes and values.
   const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0];
@@ -1881,7 +1890,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
     // Without aggressive_shape_inference, ignore annotated information.
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/false,
-        /*aggressive_shape_inference=*/false));
+        /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/true));
     const auto props = properties.GetOutputProperties("Identity");
     EXPECT_EQ(1, props.size());
     const OpInfo::TensorProperties& prop = props[0];
@@ -1895,7 +1905,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
     // Use annotated information.
     TF_CHECK_OK(properties.InferStatically(
         /*assume_valid_feeds=*/false,
-        /*aggressive_shape_inference=*/true));
+        /*aggressive_shape_inference=*/true,
+        /*include_tensor_values=*/true));
     const auto props = properties.GetOutputProperties("Identity");
     EXPECT_EQ(1, props.size());
     const OpInfo::TensorProperties& prop = props[0];
@@ -1923,7 +1934,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
   // Use annotated information.
   TF_CHECK_OK(properties.InferStatically(
       /*assume_valid_feeds=*/false,
-      /*aggressive_shape_inference=*/true));
+      /*aggressive_shape_inference=*/true,
+      /*include_tensor_values=*/true));
   const auto props = properties.GetOutputProperties("Identity");
   EXPECT_EQ(1, props.size());
   const OpInfo::TensorProperties& prop = props[0];
@@ -1950,7 +1962,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
   // Use annotated information.
   TF_CHECK_OK(properties.InferStatically(
       /*assume_valid_feeds=*/false,
-      /*aggressive_shape_inference=*/true));
+      /*aggressive_shape_inference=*/true,
+      /*include_tensor_values=*/true));
   const auto props = properties.GetOutputProperties("Identity");
   EXPECT_EQ(1, props.size());
   const OpInfo::TensorProperties& prop = props[0];
@@ -1977,7 +1990,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
   // Use annotated information.
   TF_CHECK_OK(properties.InferStatically(
       /*assume_valid_feeds=*/false,
-      /*aggressive_shape_inference=*/true));
+      /*aggressive_shape_inference=*/true,
+      /*include_tensor_values=*/true));
   const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
   EXPECT_EQ(1, props.size());
   const OpInfo::TensorProperties& prop = props[0];


@@ -359,7 +359,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
   graph_properties_ = absl::make_unique<GraphProperties>(*item);
   if (use_static_shapes_) {
     TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
-        true, use_aggressive_shape_inference_));
+        true, use_aggressive_shape_inference_, true));
   } else {
     TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
   }


@@ -234,6 +234,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
     DedupControlInputs(target_node);
   }
 
+  bool IsReallyConstant(const NodeDef& node) const {
+    if (!IsConstant(node)) {
+      return false;
+    }
+    // If the node is fed it's not constant anymore.
+    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
+  }
+
   bool IsInPreserveSet(const NodeDef& node) const {
     return ctx().nodes_to_preserve->find(node.name()) !=
            ctx().nodes_to_preserve->end();
@@ -259,6 +267,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
     return false;
   }
 
+  bool GetTensorFromConstNode(const string& node_name_or_input,
+                              Tensor* tensor) {
+    const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
+    return node != nullptr && IsReallyConstant(*node) &&
+           CheckAttrExists(*node, "value").ok() &&
+           tensor->FromProto(node->attr().at("value").tensor());
+  }
+
  private:
   // Extended context required for ArithmeticOptimizer.
   const ArithmeticOptimizerContext ctx_ext_;
@@ -2480,101 +2496,78 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
   bool IsSupported(const NodeDef* node) const override {
     return IsPow(*node) &&
-           ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
+           ctx().graph_properties->HasOutputProperties(node->name()) &&
+           ctx().graph_properties->HasInputProperties(node->name());
   }
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
-    const auto& pow_props =
-        ctx().graph_properties->GetInputProperties(node->name())[1];
-    PartialTensorShape shape(pow_props.shape());
-    if (!shape.IsFullyDefined()) {
-      // skip if p is not fully defined.
-      return Status::OK();
+    Tensor pow;
+    if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
+    complex128 prev, curr;
+    for (int i = 0; i < pow.NumElements(); ++i) {
+      if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
+        // input data type is not supported by Pow. Skip.
+        return Status::OK();
+      }
+      if (i != 0 && curr != prev) {
+        // pow has different values on different elements. Skip.
+        return Status::OK();
+      }
+      prev = curr;
     }
-    if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
-      Tensor pow(pow_props.dtype(), pow_props.shape());
-      if (!pow.FromProto(pow_props.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       pow_props.value().DebugString());
-      }
-
-      complex128 prev, curr;
-      for (int i = 0; i < pow.NumElements(); ++i) {
-        if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
-          // input data type is not supported by Pow. Skip.
-          return Status::OK();
-        }
-        if (i != 0 && curr != prev) {
-          // pow has different values on different elements. Skip.
-          return Status::OK();
-        }
-        prev = curr;
-      }
-      NodeDef *x, *y;
-      TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
-      TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
-      const auto& value_props =
-          ctx().graph_properties->GetInputProperties(node->name())[0];
-      const TensorShapeProto& output_shape =
-          ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
-      if (curr == complex128(2, 0)) {
-        node->set_op("Square");
-        node->set_input(1, AsControlDependency(y->name()));
-        AddToOptimizationQueue(node);
-        AddToOptimizationQueue(y);
-      } else if (curr == complex128(1, 0) &&
-                 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
-        // Pow could be used to broadcast, so make sure the shapes of the two
-        // arguments are identical before replacing Pow with Identity.
-        node->set_op("Identity");
-        node->set_input(1, AsControlDependency(y->name()));
-        AddToOptimizationQueue(node);
-        AddToOptimizationQueue(y);
-      } else if (curr == complex128(0.5, 0)) {
-        node->set_op("Sqrt");
-        node->set_input(1, AsControlDependency(y->name()));
-        AddToOptimizationQueue(node);
-        AddToOptimizationQueue(y);
-      } else if (curr == complex128(0, 0) &&
-                 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
-        PartialTensorShape shape(value_props.shape());
-        if (!shape.IsFullyDefined()) {
-          // skip if b is not fully defined.
-          return Status::OK();
-        }
-        if (TensorShape::IsValid(value_props.shape()) &&
-            value_props.has_value()) {
-          Tensor base(value_props.dtype(), value_props.shape());
-          if (!base.FromProto(value_props.value())) {
-            return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                           value_props.value().DebugString());
-          }
-          node->set_op("Const");
-          Tensor c(base.dtype(), base.shape());
-          for (int i = 0; i < c.NumElements(); ++i) {
-            TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
-          }
-          (*node->mutable_attr())["dtype"].set_type(base.dtype());
-          c.AsProtoTensorContent(
-              (*node->mutable_attr())["value"].mutable_tensor());
-          node->mutable_attr()->erase("T");
-          node->set_input(0, AsControlDependency(x->name()));
-          node->set_input(1, AsControlDependency(y->name()));
-          AddToOptimizationQueue(node);
-          AddToOptimizationQueue(x);
-          AddToOptimizationQueue(y);
-        }
-      } else if (curr == complex128(-0.5, 0)) {
-        node->set_op("Rsqrt");
-        node->set_input(1, AsControlDependency(y->name()));
-        AddToOptimizationQueue(node);
-        AddToOptimizationQueue(y);
-      } else if (curr == complex128(-1, 0)) {
-        node->set_op("Reciprocal");
-        node->set_input(1, AsControlDependency(y->name()));
-        AddToOptimizationQueue(node);
-        AddToOptimizationQueue(y);
+    NodeDef *x, *y;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+    const auto& value_props =
+        ctx().graph_properties->GetInputProperties(node->name())[0];
+    const TensorShapeProto& output_shape =
+        ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
+    if (curr == complex128(2, 0)) {
+      node->set_op("Square");
+      node->set_input(1, AsControlDependency(y->name()));
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
+    } else if (curr == complex128(1, 0) &&
+               ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+      // Pow could be used to broadcast, so make sure the shapes of the two
+      // arguments are identical before replacing Pow with Identity.
+      node->set_op("Identity");
+      node->set_input(1, AsControlDependency(y->name()));
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
+    } else if (curr == complex128(0.5, 0)) {
+      node->set_op("Sqrt");
+      node->set_input(1, AsControlDependency(y->name()));
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
+    } else if (curr == complex128(0, 0) &&
+               ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
+               PartialTensorShape(output_shape).IsFullyDefined()) {
+      const auto dtype = node->attr().at("T").type();
+      Tensor ones(dtype, output_shape);
+      for (int i = 0; i < ones.NumElements(); ++i) {
+        TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
       }
+      node->set_op("Const");
+      (*node->mutable_attr())["dtype"].set_type(dtype);
+      node->mutable_attr()->erase("T");
+      ones.AsProtoTensorContent(
+          (*node->mutable_attr())["value"].mutable_tensor());
+      node->set_input(0, AsControlDependency(x->name()));
+      node->set_input(1, AsControlDependency(y->name()));
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(x);
+      AddToOptimizationQueue(y);
+    } else if (curr == complex128(-0.5, 0)) {
+      node->set_op("Rsqrt");
+      node->set_input(1, AsControlDependency(y->name()));
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
+    } else if (curr == complex128(-1, 0)) {
+      node->set_op("Reciprocal");
+      node->set_input(1, AsControlDependency(y->name()));
+      AddToOptimizationQueue(node);
+      AddToOptimizationQueue(y);
     }
     return Status::OK();
   }
@@ -2638,12 +2631,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
   }
 
  private:
-  Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+  Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
                              bool* modified) {
     const auto& t =
-        ctx().graph_properties->GetInputProperties(input->name())[i];
+        ctx().graph_properties->GetInputProperties(add_node->name())[i];
     const auto& c =
-        ctx().graph_properties->GetInputProperties(input->name())[j];
+        ctx().graph_properties->GetInputProperties(add_node->name())[j];
     for (int k = 0; k < c.shape().dim_size(); ++k) {
       // Skip if c shape is not fully determined.
       if (c.shape().dim(k).size() < 0) {
@@ -2659,13 +2652,13 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
       // broadcast.
       return Status::OK();
     }
-    if (TensorShape::IsValid(c.shape()) && c.has_value()) {
-      Tensor constant(c.dtype(), c.shape());
-      if (!constant.FromProto(c.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       c.value().DebugString());
-      }
+    Tensor constant;
+    if (GetTensorFromConstNode(add_node->input(j), &constant)) {
       complex128 element;
+      // TODO(rmlarsen): Refactor the more general IsOnes from
+      // constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
+      // or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1),
+      // and Neg(-1) to 1.
       for (int k = 0; k < constant.NumElements(); ++k) {
         if (!GetElementUnexhaustive(constant, k,
                                     {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
@@ -2680,15 +2673,15 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
         }
       }
       NodeDef *x, *y;
-      TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
-      TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
+      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
       node->set_op("Log1p");
-      node->set_input(0, input->input(i));
+      node->set_input(0, add_node->input(i));
       node->add_input(AsControlDependency(y->name()));
-      ForwardControlDependencies(node, {input});
+      ForwardControlDependencies(node, {add_node});
       AddToOptimizationQueue(node);
-      AddToOptimizationQueue(input);
+      AddToOptimizationQueue(add_node);
       AddToOptimizationQueue(x);
       AddToOptimizationQueue(y);
       *modified = true;
@@ -2717,25 +2710,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
     if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
       return Status::OK();
     }
-
-    NodeDef* exp;
-    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
-    if (!IsExp(*exp)) {
-      return Status::OK();
-    }
-
-    if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) {
-      return Status::OK();
-    }
-
-    const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+    const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
     const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
-    for (int k = 0; k < c.shape().dim_size(); ++k) {
-      // Skip if c shape is not fully determined.
-      if (c.shape().dim(k).size() < 0) {
-        return Status::OK();
-      }
-    }
-
     TensorShapeProto broadcast_shape;
     if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
       return Status::OK();
@@ -2745,39 +2721,39 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
       // broadcast.
       return Status::OK();
     }
-    if (TensorShape::IsValid(c.shape()) && c.has_value()) {
-      Tensor constant(c.dtype(), c.shape());
-      if (!constant.FromProto(c.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       c.value().DebugString());
+    Tensor constant;
+    if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
+    // TODO(rmlarsen): Use the more general IsOnes helper here.
+    complex128 element;
+    for (int k = 0; k < constant.NumElements(); ++k) {
+      if (!GetElementUnexhaustive(constant, k,
+                                  {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+                                   DT_COMPLEX64, DT_COMPLEX128},
+                                  &element)) {
+        // input data type is not supported by expm1. Skip.
+        return Status::OK();
       }
-      complex128 element;
-      for (int k = 0; k < constant.NumElements(); ++k) {
-        if (!GetElementUnexhaustive(constant, k,
-                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
-                                     DT_COMPLEX64, DT_COMPLEX128},
-                                    &element)) {
-          // input data type is not supported by expm1. Skip.
-          return Status::OK();
-        }
-        if (element != complex128(1)) {
-          // current element is not 1. Skip.
-          return Status::OK();
-        }
+      LOG(INFO) << "Got element = " << element;
+      if (element != complex128(1)) {
+        // current element is not 1. Skip.
+        return Status::OK();
       }
-      NodeDef *exp_input, *ones;
-      TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
-      TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
-      node->set_op("Expm1");
-      node->set_input(0, exp->input(0));
-      node->set_input(1, AsControlDependency(ones->name()));
-      ForwardControlDependencies(node, {exp});
-      AddToOptimizationQueue(node);
-      AddToOptimizationQueue(exp);
-      AddToOptimizationQueue(exp_input);
-      AddToOptimizationQueue(ones);
     }
+    NodeDef* exp;
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
+    NodeDef *exp_input, *ones;
+    TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
+    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
+    node->set_op("Expm1");
+    node->set_input(0, exp->input(0));
+    node->set_input(1, AsControlDependency(ones->name()));
+    ForwardControlDependencies(node, {exp});
+    AddToOptimizationQueue(node);
+    AddToOptimizationQueue(exp);
+    AddToOptimizationQueue(exp_input);
+    AddToOptimizationQueue(ones);
+    *simplified_node_name = node->name();
     return Status::OK();
   }
 };
@@ -3096,14 +3072,6 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
   }
 
  protected:
-  bool IsReallyConstant(const NodeDef& node) const {
-    if (!IsConstant(node)) {
-      return false;
-    }
-    // If the node is fed it's not constant anymore.
-    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
-  }
-
   bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
                           std::vector<int64>* values) {
     if (dtype == DT_INT32) {
@@ -3430,8 +3398,6 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
 void ArithmeticOptimizer::DedupComputations() {
   CanonicalizeGraph(optimized_graph_);
-  // LOG(INFO) << "Graph after canonicalization: \n"
-  //           << optimized_graph_->DebugString();
 
   GraphTopologyView graph_view;
   if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
@@ -3683,7 +3649,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   graph_properties_.reset(new GraphProperties(optimized_item));
   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  const Status status = graph_properties_->InferStatically(assume_valid_feeds);
+  const Status status =
+      graph_properties_->InferStatically(assume_valid_feeds,
+                                         /*aggressive_shape_inference=*/false,
+                                         /*include_tensor_values=*/false);
   const bool can_use_shapes = status.ok();
   if (!can_use_shapes) {
     VLOG(1) << "Shape inference failed." << status.error_message();


@ -122,27 +122,6 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input; return removed_input;
} }
bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
int* axis) {
if (node->op() != "ConcatV2" ||
properties.GetInputProperties(node->name()).empty()) {
return false;
}
const auto& axis_input = properties.GetInputProperties(node->name()).back();
if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
return false;
}
Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
if (!axis_tensor.FromProto(axis_input.value())) {
return false;
}
*axis = axis_input.dtype() == DT_INT64
? static_cast<int>(axis_tensor.scalar<int64>()())
: axis_tensor.scalar<int32>()();
return true;
}
bool HasTPUAttributes(const NodeDef& node) { bool HasTPUAttributes(const NodeDef& node) {
AttrSlice attrs(node); AttrSlice attrs(node);
for (auto attr : attrs) { for (auto attr : attrs) {
@ -220,9 +199,9 @@ string ConstantFolding::AddControlDependency(const string& input_name,
if (IsControlInput(input_name)) { if (IsControlInput(input_name)) {
return input_name; return input_name;
} }
const NodeDef* node = node_map->GetNode(input_name); const NodeDef& node = *node_map->GetNode(input_name);
if (!IsSwitch(*node)) { if (!IsSwitch(node)) {
return AsControlDependency(*node); return AsControlDependency(node);
} else { } else {
// We can't anchor control dependencies directly on the switch node: unlike // We can't anchor control dependencies directly on the switch node: unlike
// other nodes only one of the outputs of the switch node will be generated // other nodes only one of the outputs of the switch node will be generated
@ -230,10 +209,10 @@ string ConstantFolding::AddControlDependency(const string& input_name,
// dependency is only triggered when the corresponding output is triggered. // dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the // We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency. // switch node, and use it to anchor the control dependency.
auto outputs = node_map->GetOutputs(node->name()); auto outputs = node_map->GetOutputs(node.name());
for (const NodeDef* output : outputs) { for (const NodeDef* output : outputs) {
if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) { if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
if (IsSameInput(node->input(0), input_name)) { if (IsSameInput(node.input(0), input_name)) {
return AsControlDependency(*output); return AsControlDependency(*output);
} }
} }
@ -244,19 +223,19 @@ string ConstantFolding::AddControlDependency(const string& input_name,
string ctrl_dep_name = ParseNodeName(input_name, &port); string ctrl_dep_name = ParseNodeName(input_name, &port);
strings::StrAppend(&ctrl_dep_name, "_", port); strings::StrAppend(&ctrl_dep_name, "_", port);
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl); ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type(); const DataType output_type = node.attr().at("T").type();
NodeDef* added_node = node_map->GetNode(ctrl_dep_name); NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
if (added_node == nullptr) { if (added_node == nullptr) {
added_node = graph->add_node(); added_node = graph->add_node();
added_node->set_name(ctrl_dep_name); added_node->set_name(ctrl_dep_name);
added_node->set_op("Identity"); added_node->set_op("Identity");
added_node->set_device(node->device()); added_node->set_device(node.device());
(*added_node->mutable_attr())["T"].set_type(output_type); (*added_node->mutable_attr())["T"].set_type(output_type);
*added_node->add_input() = input_name; *added_node->add_input() = input_name;
node_map->AddNode(added_node->name(), added_node); node_map->AddNode(added_node->name(), added_node);
node_map->AddOutput(node->name(), added_node->name()); node_map->AddOutput(node.name(), added_node->name());
} }
return AsControlDependency(*added_node); return AsControlDependency(*added_node);
} }
@ -321,6 +300,15 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
return feed_nodes_.find(node.name()) == feed_nodes_.end(); return feed_nodes_.find(node.name()) == feed_nodes_.end();
} }
// TODO(rmlarsen): Refactor to shared util.
bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
Tensor* tensor) {
const NodeDef* node = node_map_->GetNode(node_name_or_input);
return node != nullptr && IsReallyConstant(*node) &&
CheckAttrExists(*node, "value").ok() &&
tensor->FromProto(node->attr().at("value").tensor());
}
// Materialize the shapes using constants whenever possible. // Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies and hold // We may add some nodes to the graph to encode control dependencies and hold
@ -868,6 +856,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false; return false;
} }
if (node.op() == "AccumulateNV2") {
return false;
}
// Skips ops that don't benefit from folding. // Skips ops that don't benefit from folding.
if (IsPlaceholder(node)) { if (IsPlaceholder(node)) {
return false; return false;
@ -1856,9 +1847,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
SET_AND_RETURN_IF_MODIFIED( SET_AND_RETURN_IF_MODIFIED(
PartialAssocOpConstFolding(optimized_graph, properties, node)); PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED( SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node)); MergeConcat(use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED( SET_AND_RETURN_IF_MODIFIED(
MergeConcat(*properties, use_shape_info, optimized_graph, node)); PartialConcatConstFolding(optimized_graph, properties, node));
graph_modified_ = graph_modified_cached; graph_modified_ = graph_modified_cached;
return Status::OK(); return Status::OK();
@ -1879,43 +1870,33 @@ void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
Status ConstantFolding::RemoveShuffleOrTranspose( Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info, const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) { GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
properties.GetInputProperties(node->name()).size() >= 2) { return Status::OK();
Tensor permutation_tensor;
if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
properties.HasInputProperties(node->name())) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape(); const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) { std::vector<int> permutation;
// Not optimizable. for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
if (permutation_tensor.dtype() == DT_INT64) {
permutation.push_back(permutation_tensor.vec<int64>()(j));
} else {
permutation.push_back(permutation_tensor.vec<int>()(j));
}
}
if (permutation.size() != shape.dim_size()) {
// Number of elements in perm should be same as dim_size. Skip if not.
return Status::OK(); return Status::OK();
} }
const auto& p = properties.GetInputProperties(node->name())[1]; // The node is replaceable iff
if (TensorShape::IsValid(p.shape()) && p.has_value()) { // dim_size == 0 || all dims have size 1 ||
Tensor perm(p.dtype(), p.shape()); // all dims with > 1 size are not permuted.
if (!perm.FromProto(p.value())) { bool replaceable = true;
return errors::InvalidArgument("Cannot parse tensor from proto: ", for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
p.value().DebugString()); replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
} }
std::vector<int> permutation; if (replaceable) {
for (int j = 0; j < perm.NumElements(); ++j) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
if (perm.dtype() == DT_INT64) {
permutation.push_back(perm.vec<int64>()(j));
} else {
permutation.push_back(perm.vec<int>()(j));
}
}
if (permutation.size() != shape.dim_size()) {
// Number of elements in perm should be same as dim_size. Skip if not.
return Status::OK();
}
// The node is replaceable iff
// dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not permuted.
bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return Status::OK();
}
} }
} }
return Status::OK(); return Status::OK();
@ -1941,44 +1922,35 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
if (use_shape_info && node->op() == "ReverseV2" && if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
properties.GetInputProperties(node->name()).size() >= 2) { Tensor axis;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &axis)) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape(); const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) { if (shape.unknown_rank()) return Status::OK();
// Not optimizable. std::set<int> target_axes;
return Status::OK(); for (int j = 0; j < axis.NumElements(); ++j) {
// value of axis can be negative.
if (axis.dtype() == DT_INT64) {
target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
shape.dim_size());
} else {
target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
shape.dim_size());
}
} }
const auto& a = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(a.shape()) && a.has_value()) {
Tensor axis(a.dtype(), a.shape());
if (!axis.FromProto(a.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
a.value().DebugString());
}
std::set<int> target_axes;
for (int j = 0; j < axis.NumElements(); ++j) {
// value of axis can be negative.
if (axis.dtype() == DT_INT64) {
target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
shape.dim_size());
} else {
target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
shape.dim_size());
}
}
// The node is replaceable iff // The node is replaceable iff
// unknown_rank == false && // unknown_rank == false &&
// (dim_size == 0 || all dims have size 1 || // (dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not in target_axes) // all dims with > 1 size are not in target_axes)
bool replaceable = !shape.unknown_rank(); bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) { for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 || replaceable &=
target_axes.find(j) == target_axes.end(); shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
} }
} }
return Status::OK(); return Status::OK();
@ -1988,45 +1960,33 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
if (use_shape_info && IsSlice(*node) && if (!use_shape_info || !IsSlice(*node)) return Status::OK();
properties.GetInputProperties(node->name()).size() == 3) { Tensor begin;
Tensor size;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &begin) &&
GetTensorFromConstNode(node->input(2), &size)) {
const auto& input = properties.GetInputProperties(node->name())[0]; const auto& input = properties.GetInputProperties(node->name())[0];
const auto& b = properties.GetInputProperties(node->name())[1]; // The node is replaceable iff unknown_rank == false &&
const auto& s = properties.GetInputProperties(node->name())[2]; // begin == 0 && (size == -1 || size == input_shape) for all dimensions
if (TensorShape::IsValid(b.shape()) && b.has_value() && bool replaceable = !input.shape().unknown_rank();
TensorShape::IsValid(s.shape()) && s.has_value()) { for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
Tensor begin(b.dtype(), b.shape()); if (begin.dtype() == DT_INT32) {
if (!begin.FromProto(b.value())) { replaceable &= begin.vec<int>()(j) == 0;
return errors::InvalidArgument("Cannot parse tensor from proto: ", } else {
b.value().DebugString()); replaceable &= begin.vec<int64>()(j) == 0;
} }
Tensor size(s.dtype(), s.shape()); if (size.dtype() == DT_INT32) {
if (!size.FromProto(s.value())) { replaceable &= (size.vec<int>()(j) == -1 ||
return errors::InvalidArgument("Cannot parse tensor from proto: ", size.vec<int>()(j) == input.shape().dim(j).size());
s.value().DebugString()); } else {
} replaceable &= (size.vec<int64>()(j) == -1 ||
// The node is replaceable iff unknown_rank == false && size.vec<int64>()(j) == input.shape().dim(j).size());
// begin == 0 && (size == -1 || size == input_shape) for all dimensions
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (begin.dtype() == DT_INT32) {
replaceable &= begin.vec<int>()(j) == 0;
} else {
replaceable &= begin.vec<int64>()(j) == 0;
}
if (size.dtype() == DT_INT32) {
replaceable &= (size.vec<int>()(j) == -1 ||
size.vec<int>()(j) == input.shape().dim(j).size());
} else {
replaceable &= (size.vec<int64>()(j) == -1 ||
size.vec<int64>()(j) == input.shape().dim(j).size());
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return Status::OK();
} }
} }
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
} }
return Status::OK(); return Status::OK();
} }
@ -2052,81 +2012,70 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
return Status::OK(); return Status::OK();
} }
} }
const auto& b = properties.GetInputProperties(node->name())[1];
const auto& e = properties.GetInputProperties(node->name())[2];
const auto& s = properties.GetInputProperties(node->name())[3];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(e.shape()) && e.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
if (!begin.FromProto(b.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
b.value().DebugString());
}
Tensor end(e.dtype(), e.shape());
if (!end.FromProto(e.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
e.value().DebugString());
}
Tensor strides(s.dtype(), s.shape());
if (!strides.FromProto(s.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
s.value().DebugString());
}
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i();
int end_mask = node->attr().at("end_mask").i();
std::set<int> expanded_ellipsis_indices;
int ellipsis_index = -1;
for (int j = 0; j < input.shape().dim_size(); ++j) {
// find the ellipsis_mask. If not found, insert one in the end if
// necessary.
if (node->attr().at("ellipsis_mask").i() & 1 << j ||
(ellipsis_index == -1 && j >= strides.NumElements())) {
ellipsis_index = j;
}
// insert the indices that are immediately after ellipsis_index if
// necessary.
if (ellipsis_index != -1 &&
input.shape().dim_size() >
strides.NumElements() + j - ellipsis_index) {
expanded_ellipsis_indices.insert(j);
}
}
// The node is replaceable iff unknown_rank == false && std::vector<Tensor> input_tensors(3);
// ((begin_mask is set || begin == 0) && (end_mask is set || end == dim) for (int i = 1; i < 4; ++i) {
// && strides == 1) for all dimensions. if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
bool replaceable = !input.shape().unknown_rank(); return Status::OK();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (expanded_ellipsis_indices.find(j) !=
expanded_ellipsis_indices.end()) {
// ellipsis_mask is effective on current dimension.
continue;
}
// when we have ellipsis_mask in between, input.shape().dim_size() will
// be greater than strides.NumElements(), since we will insert
// as many as expanded_ellipsis_indices.size() axes during computation.
// We need to subtract this number from j.
int i = j;
if (ellipsis_index != -1 &&
j >= ellipsis_index + expanded_ellipsis_indices.size()) {
i = j - expanded_ellipsis_indices.size();
}
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
: begin.vec<int64>()(i);
int e =
end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
: strides.vec<int64>()(i);
replaceable &=
(begin_mask & 1 << i || b == 0) &&
(end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
} }
if (replaceable) { }
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
const Tensor& begin = input_tensors[0];
const Tensor& end = input_tensors[1];
const Tensor& strides = input_tensors[2];
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i();
int end_mask = node->attr().at("end_mask").i();
std::set<int> expanded_ellipsis_indices;
int ellipsis_index = -1;
for (int j = 0; j < input.shape().dim_size(); ++j) {
// find the ellipsis_mask. If not found, insert one in the end if
// necessary.
if (node->attr().at("ellipsis_mask").i() & 1 << j ||
(ellipsis_index == -1 && j >= strides.NumElements())) {
ellipsis_index = j;
} }
// insert the indices that are immediately after ellipsis_index if
// necessary.
if (ellipsis_index != -1 &&
input.shape().dim_size() >
strides.NumElements() + j - ellipsis_index) {
expanded_ellipsis_indices.insert(j);
}
}
// The node is replaceable iff unknown_rank == false &&
// ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
// && strides == 1) for all dimensions.
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (expanded_ellipsis_indices.find(j) !=
expanded_ellipsis_indices.end()) {
// ellipsis_mask is effective on current dimension.
continue;
}
// when we have ellipsis_mask in between, input.shape().dim_size() will
// be greater than strides.NumElements(), since we will insert
// as many as expanded_ellipsis_indices.size() axes during computation.
// We need to subtract this number from j.
int i = j;
if (ellipsis_index != -1 &&
j >= ellipsis_index + expanded_ellipsis_indices.size()) {
i = j - expanded_ellipsis_indices.size();
}
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
: begin.vec<int64>()(i);
int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
: strides.vec<int64>()(i);
replaceable &= (begin_mask & 1 << i || b == 0) &&
(end_mask & 1 << i || e == input.shape().dim(j).size()) &&
s == 1;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} }
} }
return Status::OK(); return Status::OK();
@ -2135,31 +2084,23 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
Status ConstantFolding::SimplifyTile(const GraphProperties& properties, Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) { GraphDef* optimized_graph, NodeDef* node) {
Tensor multiplies;
if (use_shape_info && IsTile(*node) && if (use_shape_info && IsTile(*node) &&
properties.GetInputProperties(node->name()).size() == 2) { GetTensorFromConstNode(node->input(1), &multiplies)) {
const auto& m = properties.GetInputProperties(node->name())[1]; // The node is replaceable iff all values in multiplies are 1.
if (TensorShape::IsValid(m.shape()) && m.has_value()) { bool replaceable = true;
Tensor multiplies(m.dtype(), m.shape()); if (multiplies.dtype() == DT_INT32) {
if (!multiplies.FromProto(m.value())) { for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
return errors::InvalidArgument("Cannot parse tensor from proto: ", replaceable &= multiplies.vec<int>()(j) == 1;
m.value().DebugString());
} }
// The node is replaceable iff all values in multiplies are 1. } else {
bool replaceable = true; for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
if (multiplies.dtype() == DT_INT32) { replaceable &= multiplies.vec<int64>()(j) == 1;
for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
replaceable &= multiplies.vec<int>()(j) == 1;
}
} else {
for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
++j) {
replaceable &= multiplies.vec<int64>()(j) == 1;
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} }
} }
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
} }
return Status::OK(); return Status::OK();
} }
@@ -2167,26 +2108,20 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
                                    bool use_shape_info,
                                    GraphDef* optimized_graph, NodeDef* node) {
-  if (use_shape_info && IsPad(*node) &&
-      properties.GetInputProperties(node->name()).size() >= 2) {
-    const auto& p = properties.GetInputProperties(node->name())[1];
-    if (TensorShape::IsValid(p.shape()) && p.has_value()) {
-      Tensor paddings(p.dtype(), p.shape());
-      if (!paddings.FromProto(p.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       p.value().DebugString());
-      }
-      // The node is replaceable iff all values in paddings are 0.
-      bool replaceable = true;
-      // The operation requires it to be int32 value so we don't check for
-      // 1nt64.
-      const auto flatten = paddings.flat<int32>();
-      for (int j = 0; replaceable && j < flatten.size(); ++j) {
-        replaceable &= flatten(j) == 0;
-      }
-      if (replaceable) {
-        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-      }
-    }
+  if (!use_shape_info || !IsPad(*node)) return Status::OK();
+
+  Tensor paddings;
+  if (GetTensorFromConstNode(node->input(1), &paddings)) {
+    // The node is replaceable iff all values in paddings are 0.
+    bool replaceable = true;
+    // The operation requires the paddings to be int32, so we don't check for
+    // int64.
+    const auto flatten = paddings.flat<int32>();
+    for (int j = 0; replaceable && j < flatten.size(); ++j) {
+      replaceable &= flatten(j) == 0;
+    }
+    if (replaceable) {
+      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+    }
  }
  return Status::OK();
@@ -3031,73 +2966,71 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
  // folding of ops when more than one but not all inputs are constant.
  // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
  // addition is commutative.
-  const int num_non_control_inputs = NumNonControlInputs(*node);
-  if (IsAggregate(*node) && IsCommutative(*node) &&
-      num_non_control_inputs > 2) {
-    const int num_control_inputs = node->input_size() - num_non_control_inputs;
-    std::vector<int> const_inputs;
-    std::vector<int> nonconst_inputs;
-    for (int i = 0; i < node->input_size(); ++i) {
-      const string& input = node->input(i);
-      const NodeDef* input_node = node_map_->GetNode(NodeName(input));
-      CHECK(input_node != nullptr) << input;
-      if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
-        const_inputs.push_back(i);
-      } else {
-        // Non-const and control inputs.
-        nonconst_inputs.push_back(i);
-      }
-    }
-    // Promote AccumulateNV2 with all constant inputs to AddN, since it is
-    // a fake node that cannot be constant folded by itself.
-    if (const_inputs.size() == num_non_control_inputs &&
-        node->op() == "AccumulateNV2") {
-      node->set_op("AddN");
-      node->mutable_attr()->erase("shape");
-      return true;
-    }
-    const string new_node_name = OptimizedNodeName(
-        *node, strings::StrCat("_partial_split_", const_inputs.size()));
-    if (1 < const_inputs.size() &&
-        const_inputs.size() < num_non_control_inputs &&
-        !node_map_->NodeExists(new_node_name)) {
-      NodeDef* added_node = optimized_graph->add_node();
-      *added_node = *node;
-      // Always use AddN for the constant node, since AccumulateNV2 is a fake
-      // node that cannot be constant folded, since it does not have a kernel.
-      added_node->set_op("AddN");
-      added_node->mutable_attr()->erase("shape");
-      added_node->set_name(new_node_name);
-      node_map_->AddNode(added_node->name(), added_node);
-      added_node->clear_input();
-      for (int i : const_inputs) {
-        added_node->add_input(node->input(i));
-        node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
-                                added_node->name());
-      }
-      // Overwrite the first const input with the added node.
-      node->set_input(const_inputs[0], added_node->name());
-      node_map_->AddOutput(added_node->name(), node->name());
-      nonconst_inputs.push_back(const_inputs[0]);
-      // Compact the remaining inputs to the original node.
-      std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
-      int idx = 0;
-      for (int i : nonconst_inputs) {
-        if (idx != i) {
-          node->set_input(idx, node->input(i));
-        }
-        ++idx;
-      }
-      node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
-                                            const_inputs.size() - 1);
-      (*node->mutable_attr())["N"].set_i(node->input_size() -
-                                         num_control_inputs);
-      properties->ClearInputProperties(node->name());
-      (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
-      return true;
-    }
-  }
+  if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
+  const int num_non_control_inputs = NumNonControlInputs(*node);
+  if (num_non_control_inputs <= 2) return false;
+  const int num_control_inputs = node->input_size() - num_non_control_inputs;
+  std::vector<int> const_inputs;
+  std::vector<int> nonconst_inputs;
+  for (int i = 0; i < node->input_size(); ++i) {
+    const string& input = node->input(i);
+    const NodeDef* input_node = node_map_->GetNode(NodeName(input));
+    if (input_node == nullptr) return false;
+    if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
+      const_inputs.push_back(i);
+    } else {
+      // Non-const and control inputs.
+      nonconst_inputs.push_back(i);
+    }
+  }
+  // Promote AccumulateNV2 with all constant inputs to AddN, since it is
+  // a fake node that cannot be constant folded by itself.
+  if (const_inputs.size() == num_non_control_inputs &&
+      node->op() == "AccumulateNV2") {
+    node->set_op("AddN");
+    node->mutable_attr()->erase("shape");
+    return true;
+  }
+  const string new_node_name = OptimizedNodeName(
+      *node, strings::StrCat("_partial_split_", const_inputs.size()));
+  if (const_inputs.size() > 1 && const_inputs.size() < num_non_control_inputs &&
+      !node_map_->NodeExists(new_node_name)) {
+    NodeDef* added_node = optimized_graph->add_node();
+    *added_node = *node;
+    // Always use AddN for the constant node, since AccumulateNV2 is a fake
+    // node that cannot be constant folded, since it does not have a kernel.
+    added_node->set_op("AddN");
+    added_node->mutable_attr()->erase("shape");
+    added_node->set_name(new_node_name);
+    node_map_->AddNode(added_node->name(), added_node);
+    added_node->clear_input();
+    for (int i : const_inputs) {
+      added_node->add_input(node->input(i));
+      node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
+                              added_node->name());
+    }
+    // Overwrite the first const input with the added node.
+    node->set_input(const_inputs[0], added_node->name());
+    node_map_->AddOutput(added_node->name(), node->name());
+    nonconst_inputs.push_back(const_inputs[0]);
+    // Compact the remaining inputs to the original node.
+    std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
+    int idx = 0;
+    for (int i : nonconst_inputs) {
+      if (idx != i) {
+        node->set_input(idx, node->input(i));
+      }
+      ++idx;
+    }
+    node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
+                                          const_inputs.size() - 1);
+    (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
+    properties->ClearInputProperties(node->name());
+    (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
+    return true;
+  }
  return false;
}
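To make the rewrite above concrete: when only some inputs of an AddN/AccumulateNV2 are constant, the constants are split off into a new child AddN and the parent keeps the child's output plus the non-constant inputs. A standalone sketch of that partitioning (plain C++ with illustrative names, no TF APIs):

#include <functional>
#include <string>
#include <vector>

struct PartialSplit {
  std::vector<std::string> child_inputs;   // all-constant inputs, foldable
  std::vector<std::string> parent_inputs;  // child output + non-const inputs
};

PartialSplit SplitConstantInputs(
    const std::vector<std::string>& inputs, const std::string& child_name,
    const std::function<bool(const std::string&)>& is_constant) {
  PartialSplit result;
  result.parent_inputs.push_back(child_name);  // stands in for the constants
  for (const std::string& input : inputs) {
    if (is_constant(input)) {
      result.child_inputs.push_back(input);
    } else {
      result.parent_inputs.push_back(input);
    }
  }
  return result;
}

// For inputs {"c1", "c2", "x"} with c1 and c2 constant, this yields
// child = AddN(c1, c2) and parent = AddN(child, x), mirroring the code above.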
@@ -3107,156 +3040,176 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
  // Partial constant folding for Concat which is not commutative, so
  // we have to preserve order and can only push consecutive runs of constant
  // inputs into sub-nodes.
-  const int num_non_control_inputs = NumNonControlInputs(*node);
-  if (IsConcat(*node) && num_non_control_inputs > 3 &&
-      node->name().rfind("_partial_split_") == string::npos) {
-    int axis_arg = -1;
-    int begin = 0;
-    int end = num_non_control_inputs;
-    if (node->op() == "Concat") {
-      begin = 1;
-      axis_arg = 0;
-    } else if (node->op() == "ConcatV2") {
-      end = num_non_control_inputs - 1;
-      axis_arg = num_non_control_inputs - 1;
-    } else {
-      return false;
-    }
-    const NodeDef* axis_arg_node =
-        node_map_->GetNode(NodeName(node->input(axis_arg)));
-    if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
-      // We cannot constant fold Concat unless we the axis argument is
-      // constant. Skip node.
-      return false;
-    }
-    // We search for consecutive runs of constant inputs in the range
-    // [begin:end[ and push then down into child nodes.
-    std::vector<std::pair<int, int>> constant_input_runs;
-    int first = begin;
-    int last = begin;
-    while (last < end) {
-      while (first < end && !IsReallyConstant(*node_map_->GetNode(
-                                NodeName(node->input(first))))) {
-        ++first;
-      }
-      // Invariant: node[first] is constant || first >= end.
-      last = first + 1;
-      while (last < end && IsReallyConstant(*node_map_->GetNode(
-                               NodeName(node->input(last))))) {
-        ++last;
-      }
-      // Invariant: node[last] is not constant || last >= end
-      // Discard intervals shorter than 2 elements.
-      if (first < end && (last - first) > 1) {
-        constant_input_runs.emplace_back(first, last);
-      }
-      first = last;
-    }
-    // Skip if all inputs are constant, and let constant folding take over.
-    if (constant_input_runs.size() == 1 &&
-        constant_input_runs[0].first == begin &&
-        constant_input_runs[0].second == end) {
-      return false;
-    }
-    std::set<int> inputs_to_delete;
-    for (auto interval : constant_input_runs) {
-      // Push the constant inputs in the interval to a child node than can be
-      // constant folded.
-      string new_node_name = OptimizedNodeName(*node, "_partial_split");
-      do {
-        new_node_name += strings::StrCat("_", interval.first);
-      } while (node_map_->NodeExists(new_node_name));
-      NodeDef* added_node = optimized_graph->add_node();
-      *added_node = *node;
-      added_node->set_name(new_node_name);
-      node_map_->AddNode(added_node->name(), added_node);
-      added_node->clear_input();
-      for (int i = interval.first; i < interval.second; ++i) {
-        added_node->add_input(node->input(i));
-        node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
-                                added_node->name());
-        if (i != interval.first) {
-          inputs_to_delete.insert(i);
-        }
-      }
-      added_node->add_input(node->input(axis_arg));
-      (*added_node->mutable_attr())["N"].set_i(interval.second -
-                                               interval.first);
-      node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
-      // Overwrite the first constant input with the result of the added
-      // child node.
-      node->set_input(interval.first, added_node->name());
-      node_map_->AddOutput(added_node->name(), node->name());
-    }
-    if (!constant_input_runs.empty()) {
-      if (!inputs_to_delete.empty()) {
-        // Fix up the inputs to the original node.
-        std::vector<string> tmp(node->input().begin(), node->input().end());
-        node->clear_input();
-        for (int i = 0; i < tmp.size(); ++i) {
-          if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
-            node->add_input(tmp[i]);
-          }
-        }
-        (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
-        properties->ClearInputProperties(node->name());
-      }
-      return true;
-    }
-  }
-  return false;
+  if (!IsConcat(*node) ||
+      node->name().rfind("_partial_split_") != string::npos) {
+    return false;
+  }
+  const int num_non_control_inputs = NumNonControlInputs(*node);
+  if (num_non_control_inputs <= 3) return false;
+  int axis_arg = -1;
+  int begin = 0;
+  int end = num_non_control_inputs;
+  if (node->op() == "Concat") {
+    begin = 1;
+    axis_arg = 0;
+  } else if (node->op() == "ConcatV2") {
+    end = num_non_control_inputs - 1;
+    axis_arg = num_non_control_inputs - 1;
+  } else {
+    return false;
+  }
+  // We search for consecutive runs of constant inputs in the range
+  // [begin, end) and push them down into child nodes.
+  std::vector<std::pair<int, int>> constant_input_runs;
+  int first = begin;
+  int last = begin;
+  while (last < end) {
+    while (first < end && !IsReallyConstant(*node_map_->GetNode(
                               NodeName(node->input(first))))) {
+      ++first;
+    }
+    // Invariant: node[first] is constant || first >= end.
+    last = first + 1;
+    while (last < end &&
+           IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
+      ++last;
+    }
+    // Invariant: node[last] is not constant || last >= end
+    // Discard intervals shorter than 2 elements.
+    if (first < end && (last - first) > 1) {
+      constant_input_runs.emplace_back(first, last);
+    }
+    first = last;
+  }
+  // Skip if all inputs are constant, and let constant folding take over.
+  if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
                                       constant_input_runs[0].first == begin &&
                                       constant_input_runs[0].second == end)) {
+    return false;
+  }
+  std::set<int> inputs_to_delete;
+  for (auto interval : constant_input_runs) {
+    // Push the constant inputs in the interval to a child node that can be
+    // constant folded.
+    string new_node_name = OptimizedNodeName(*node, "_partial_split");
+    do {
+      new_node_name += strings::StrCat("_", interval.first);
+    } while (node_map_->NodeExists(new_node_name));
+    NodeDef* added_node = optimized_graph->add_node();
+    *added_node = *node;
+    added_node->set_op("ConcatV2");
+    added_node->set_name(new_node_name);
+    node_map_->AddNode(added_node->name(), added_node);
+    added_node->clear_input();
+    for (int i = interval.first; i < interval.second; ++i) {
+      added_node->add_input(node->input(i));
+      node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
+      if (i != interval.first) {
+        inputs_to_delete.insert(i);
+      }
+    }
+    added_node->add_input(node->input(axis_arg));
+    (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
+    node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
+    // Overwrite the first constant input with the result of the added
+    // child node.
+    node->set_input(interval.first, added_node->name());
+  }
+  if (!constant_input_runs.empty() && !inputs_to_delete.empty()) {
+    // Fix up the inputs to the original node.
+    protobuf::RepeatedPtrField<string> tmp;
+    tmp.Swap(node->mutable_input());
+    for (int i = 0; i < tmp.size(); ++i) {
+      if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+        node->add_input(tmp.Get(i));
+      }
+    }
+    (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+    properties->ClearInputProperties(node->name());
+  }
+  return true;
}
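The scanning loop above collects maximal runs of at least two consecutive constant inputs in the half-open range [begin, end); each run is then pushed into a child ConcatV2 that plain constant folding can evaluate. A standalone sketch of that scan (plain C++, illustrative only):

#include <utility>
#include <vector>

// Return all maximal runs [first, last) of two or more consecutive constant
// inputs within [begin, end).
std::vector<std::pair<int, int>> FindConstantRuns(
    const std::vector<bool>& is_constant, int begin, int end) {
  std::vector<std::pair<int, int>> runs;
  int first = begin;
  while (first < end) {
    while (first < end && !is_constant[first]) ++first;
    int last = first;
    while (last < end && is_constant[last]) ++last;
    if (last - first > 1) runs.emplace_back(first, last);  // keep runs >= 2
    first = last;
  }
  return runs;
}

// For inputs (c, c, x, c, c, c, y) with begin = 0 and end = 7 this yields the
// runs [0, 2) and [3, 6).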

+bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
+  if (node.op() != "ConcatV2") {
+    return false;
+  }
+  int axis_idx = node.input_size() - 1;
+  while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
+    --axis_idx;
+  }
+  if (axis_idx <= 0) {
+    return false;
+  }
+  Tensor axis_tensor;
+  if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
+    return false;
+  }
+  *axis = axis_tensor.dtype() == DT_INT64
+              ? static_cast<int>(axis_tensor.scalar<int64>()())
+              : axis_tensor.scalar<int32>()();
+  return true;
+}
+
-bool ConstantFolding::MergeConcat(const GraphProperties& properties,
-                                  bool use_shape_info,
+bool ConstantFolding::MergeConcat(bool use_shape_info,
                                   GraphDef* optimized_graph, NodeDef* node) {
  // We only optimize for ConcatV2.
  int axis;
-  if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+  if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
      nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
      node_map_->GetOutputs(node->name()).size() != 1) {
    return false;
  }
+  // If all inputs are constant, don't merge and let folding take care of it.
+  const int num_regular_inputs = NumNonControlInputs(*node);
+  bool all_inputs_are_const = true;
+  for (int i = 0; i < num_regular_inputs - 1; ++i) {
+    const NodeDef* input_node = node_map_->GetNode(node->input(i));
+    if (!IsReallyConstant(*input_node)) {
+      all_inputs_are_const = false;
+    }
+  }
+  if (all_inputs_are_const) return false;
  NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
  int parent_axis;
-  if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+  if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
    return false;
  }
-  const int index = NumNonControlInputs(*node) - 1;
-  auto inputs = parent->input();
-  parent->clear_input();
-  for (int i = 0; i < inputs.size(); ++i) {
-    if (IsSameInput(inputs.Get(i), node->name())) {
-      for (int j = 0; j < node->input_size(); ++j) {
-        if (j < index) {
-          // Input tensors (non axis), add to input list of parent.
-          parent->add_input(node->input(j));
-          node_map_->RemoveOutput(node->input(j), node->name());
-          node_map_->AddOutput(node->input(j), parent->name());
-        }
-        // Skip j == index, which means axis tensor.
-        if (j > index) {
-          // Control Dependencies, push back to inputs so they can be forwarded
-          // to parent.
-          *inputs.Add() = node->input(j);
-        }
-      }
-    } else {
-      parent->add_input(inputs.Get(i));
-    }
-  }
+  protobuf::RepeatedPtrField<string> parent_inputs;
+  parent_inputs.Swap(parent->mutable_input());
+  std::vector<string> ctrl_output;
+  // TODO(rmlarsen): If the child occurs more than once, is it beneficial to
+  // collapse it into the parent multiple times? Probably not.
+  for (const auto& input : parent_inputs) {
+    if (IsSameInput(input, node->name())) {
+      for (int j = 0; j < num_regular_inputs - 1; ++j) {
+        // Add the child's tensor inputs (everything except the final axis
+        // input) to the parent's inputs.
+        parent->add_input(node->input(j));
+        node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
+      }
+    } else {
+      parent->add_input(input);
+    }
+  }
+  // Forward the child's control inputs to the parent.
+  for (int i = num_regular_inputs; i < node->input_size(); ++i) {
+    parent->add_input(node->input(i));
+    node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
+  }
  node->clear_input();
  node->set_op("NoOp");
  node->clear_attr();
  node_map_->RemoveNode(node->name());
  (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+  DedupControlInputs(parent);
  return true;
}
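MergeConcat above collapses a ConcatV2 child into its single ConcatV2 consumer when both use the same constant axis: the child's tensor inputs are spliced into the parent's input list in place of the child itself. A standalone sketch of that input-list surgery (plain C++, illustrative only; control inputs and NodeMap bookkeeping are omitted):

#include <string>
#include <vector>

std::vector<std::string> MergeChildIntoParent(
    const std::vector<std::string>& parent_inputs,
    const std::string& child_name,
    const std::vector<std::string>& child_tensor_inputs) {
  std::vector<std::string> merged;
  for (const std::string& input : parent_inputs) {
    if (input == child_name) {
      // Replace the child edge by the child's own (non-axis) tensor inputs.
      merged.insert(merged.end(), child_tensor_inputs.begin(),
                    child_tensor_inputs.end());
    } else {
      merged.push_back(input);
    }
  }
  return merged;
}

// E.g. parent ConcatV2(child, y, axis) with child ConcatV2(a, b, axis)
// becomes ConcatV2(a, b, y, axis) after the merge.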
@@ -3344,7 +3297,9 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
  // that the shape inference deals with this conservatively unless we're in
  // aggressive mode.
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  Status s = properties.InferStatically(assume_valid_feeds);
+  Status s = properties.InferStatically(assume_valid_feeds,
+                                        /*aggressive_shape_inference=*/false,
+                                        /*include_tensor_values=*/false);
  const bool can_use_shape_info = s.ok();
  if (can_use_shape_info) {
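Constant folding above now requests shape and dtype information only. A minimal sketch of what such a caller looks like in isolation (the helper function below is hypothetical and not from this commit; only the InferStatically call mirrors the diff):

#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"

// Infer static shapes/dtypes for an item without copying constant tensor
// values into the inferred properties.
tensorflow::Status InferShapesOnly(
    const tensorflow::grappler::GrapplerItem& item) {
  tensorflow::grappler::GraphProperties properties(item);
  return properties.InferStatically(/*assume_valid_feeds=*/false,
                                    /*aggressive_shape_inference=*/false,
                                    /*include_tensor_values=*/false);
}

Optimizers that still read tensor values out of the inferred properties (for example the remapper change later in this diff) pass include_tensor_values=true instead.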

View File

@@ -61,6 +61,8 @@ class ConstantFolding : public GraphOptimizer {
  bool IsReallyConstant(const NodeDef& node) const;
+
+  bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
  Status MaterializeShapes(const GraphProperties& properties);
  Status MaterializeBroadcastGradientArgs(const NodeDef& node,
@@ -239,8 +241,9 @@ class ConstantFolding : public GraphOptimizer {
  void RemoveSplitOrSplitV(const GraphProperties& properties,
                           GraphDef* optimized_graph, NodeDef* node);
-  bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
-                   GraphDef* optimized_graph, NodeDef* node);
+  bool GetConcatAxis(const NodeDef& node, int* axis);
+  bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
+                   NodeDef* node);
  Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
                                               GraphDef* optimized_graph);

View File

@@ -2390,12 +2390,11 @@ TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
  TF_EXPECT_OK(status);
  GraphDef want;
-  AddNode("ConstantFolding/concat2_partial_split_0_0", "Const", {}, {}, &want);
+  AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
  AddNode("axis", "Const", {}, {}, &want);
  AddNode("ph", "Placeholder", {}, {}, &want);
  AddNode("concat2", "ConcatV2",
-          {"ConstantFolding/concat2_partial_split_0_0", "ph", "axis"}, {},
-          &want);
+          {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
  CompareGraphs(want, got);
}

View File

@@ -2209,7 +2209,10 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
  }
  GraphProperties graph_properties(item);
-  TF_RETURN_IF_ERROR(graph_properties.InferStatically(false));
+  TF_RETURN_IF_ERROR(
+      graph_properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
  virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));

View File

@@ -102,7 +102,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
-        /*assume_valid_feeds=*/false));
+        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  if (port_id >= output_properties.size()) {
@@ -252,7 +253,8 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
-        /*assume_valid_feeds=*/false));
+        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {

View File

@@ -1170,7 +1170,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
    // Infer properties lazily in case they are not needed.
    if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
-      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(false));
+      // TODO(rmlarsen): Get rid of tensor value copies.
+      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
+          /*assume_valid_feeds=*/false,
+          /*aggressive_shape_inference=*/false,
+          /*include_tensor_values=*/true));
      ctx.inferred_graph_properties = true;
    }

View File

@@ -739,9 +739,9 @@ Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
  GraphProperties graph_properties(item);
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  LOG_WARNING_AND_RETURN_IF_ERROR(
-      graph_properties.InferStatically(assume_valid_feeds));
+  LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
+      assume_valid_feeds, /*aggressive_shape_inference=*/false,
+      /*include_tensor_values=*/false));
  *optimized_graph = item.graph;
  node_map_.reset(new NodeMap(optimized_graph));

View File

@@ -87,7 +87,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
          graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
      if (!inferred_properties) {
        // Infer properties lazily in case they are not needed.
-        TF_RETURN_IF_ERROR(properties.InferStatically(false));
+        TF_RETURN_IF_ERROR(
+            properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
        inferred_properties = true;
      }
      const auto& prop =
@@ -144,7 +147,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
      }
      if (!inferred_properties) {
        // Infer properties lazily in case they are not needed.
-        TF_RETURN_IF_ERROR(properties.InferStatically(false));
+        TF_RETURN_IF_ERROR(
+            properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
        inferred_properties = true;
      }
      const auto& prop1 = properties.GetInputProperties(input1.node->name());

View File

@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/static_schedule.h"

#include <deque>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
@@ -92,7 +94,10 @@ Status EstimateEarliestExecutionTimes(
  name_map.clear();
  GraphProperties properties(item);
-  TF_RETURN_IF_ERROR(properties.InferStatically(true));
+  TF_RETURN_IF_ERROR(
+      properties.InferStatically(/*assume_valid_feeds=*/true,
+                                 /*aggressive_shape_inference=*/false,
+                                 /*include_tensor_values=*/false));
  OpLevelCostEstimator estimator;
  VirtualPlacer placer(cluster->GetDevices());
@@ -160,7 +165,10 @@ Status EstimateRequiredTimes(
    }
  }
  GraphProperties properties(item);
-  TF_RETURN_IF_ERROR(properties.InferStatically(true));
+  TF_RETURN_IF_ERROR(
+      properties.InferStatically(/*assume_valid_feeds=*/true,
+                                 /*aggressive_shape_inference=*/false,
+                                 /*include_tensor_values=*/false));
  OpLevelCostEstimator estimator;
  VirtualPlacer placer(cluster->GetDevices());

View File

@@ -346,7 +346,8 @@ class FoldOldBatchNormsTest : public ::testing::Test {
    std::vector<Tensor> fused_outputs;
    TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
-    test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
+    test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
+                      /*rtol=*/2e-5);
    for (const NodeDef& node : fused_graph_def.node()) {
      EXPECT_NE("FusedBatchNorm", node.op());