Add an option to avoid copying TensorProtos during shape inference when they are not needed. This saves a significant amount of memory and time in the various optimizers that require static shape inference. Previously, calling GraphProperties::InferStatically made at least one copy of every constant tensor stored in the graph. This CL adds a new argument to InferStatically that lets callers skip copying tensor values into InputProperties and OutputProperties when an optimizer only needs shape and datatype information.
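
For example, an optimizer that only needs shapes and dtypes could call the new overload roughly as follows (a sketch based on the signature added in this CL; `item` and `node_name` are placeholders and error handling is elided):

    GraphProperties properties(item);
    TF_RETURN_IF_ERROR(properties.InferStatically(
        /*assume_valid_feeds=*/false,
        /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));  // skip copying constant TensorProtos
    // Shapes and dtypes are populated as before; the value field of
    // InputProperties/OutputProperties is simply left unset.
    const auto& out_props = properties.GetOutputProperties(node_name);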

Turning this off for constant folding and the arithmetic optimizer required cleaning up a few rewrite rules that read the copied constant tensor values from GraphProperties instead of extracting the original value directly from the corresponding constant node in the graph.
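
The replacement pattern in those rewrite rules is roughly the following (a sketch mirroring the GetTensorFromConstNode helper added in this change):

    // Before: parse the tensor from the value copied into GraphProperties.
    //   const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
    //   Tensor t(p.dtype(), p.shape());
    //   if (!t.FromProto(p.value())) { ... }
    // After: read the value straight from the constant node in the graph.
    Tensor t;
    const NodeDef* const_node = ctx().node_map->GetNode(node->input(1));
    if (const_node != nullptr && IsReallyConstant(*const_node) &&
        CheckAttrExists(*const_node, "value").ok() &&
        t.FromProto(const_node->attr().at("value").tensor())) {
      // ... use t ...
    }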

Measured results on Transformer graph:

CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB
Benchmark                          Time(ns)     CPU(ns)    Allocs  Iterations   Peak-mem
-----------------------------------------------------------------------------------------
BM_OptimizeTransformer (before)  5161658741  5281083018  23532315           1  139.189MB
BM_OptimizeTransformer (after)   4887891650  5005937073  23063225           1  132.652MB

Effects on this graph are not dramatic, since it does not contain a lot of large constant tensors. The effect would be much more pronounced on frozen inference graphs.

PiperOrigin-RevId: 250740251
Author: A. Unique TensorFlower (committed by TensorFlower Gardener)
Date:   2019-05-30 12:36:28 -07:00
Commit: 705b193812 (parent d7d2307248)
15 changed files with 643 additions and 663 deletions

@ -712,7 +712,10 @@ class SymbolicShapeRefiner {
// Perform inference on function body. // Perform inference on function body.
GraphProperties gp(grappler_function_item); GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_)); TF_RETURN_IF_ERROR(gp.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/aggressive_shape_inference_,
/*include_tensor_values=*/true));
// Add return nodes for output shapes. // Add return nodes for output shapes.
int output = 0; int output = 0;
@ -2066,7 +2069,8 @@ Status GraphProperties::UpdateEnqueue(
} }
Status GraphProperties::InferStatically(bool assume_valid_feeds, Status GraphProperties::InferStatically(bool assume_valid_feeds,
bool aggressive_shape_inference) { bool aggressive_shape_inference,
bool include_tensor_values) {
FunctionLibraryDefinition function_library(OpRegistry::Global(), FunctionLibraryDefinition function_library(OpRegistry::Global(),
item_.graph.library()); item_.graph.library());
std::unordered_map<string, std::unordered_set<int>> fed_ports; std::unordered_map<string, std::unordered_set<int>> fed_ports;
@ -2225,9 +2229,11 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
&input_properties[i]); &input_properties[i]);
input.port_id = i; input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input); GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
if (include_tensor_values) {
// Export tensor value to input_properties.value. // Export tensor value to input_properties.value.
if (IsConstant(*fanin.node)) { if (IsConstant(*fanin.node)) {
const TensorProto& raw_val = fanin.node->attr().at("value").tensor(); const TensorProto& raw_val =
fanin.node->attr().at("value").tensor();
*input_properties[i].mutable_value() = raw_val; *input_properties[i].mutable_value() = raw_val;
} else if (ctx->input_tensor_protos.size() > i && } else if (ctx->input_tensor_protos.size() > i &&
ctx->input_tensor_protos[i] != nullptr) { ctx->input_tensor_protos[i] != nullptr) {
@ -2242,6 +2248,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
} }
} }
} }
}
// Fill output properties. // Fill output properties.
{ {
@ -2254,13 +2261,16 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
for (int i = 0; i < ic->num_outputs(); ++i) { for (int i = 0; i < ic->num_outputs(); ++i) {
shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i], shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]); &output_properties[i]);
if (include_tensor_values) {
// Export tensor value to output_properties.value. // Export tensor value to output_properties.value.
if (IsConstant(node)) { if (IsConstant(node)) {
// TODO(rmlarsen): Eliminate this copy.
const TensorProto& raw_val = node.attr().at("value").tensor(); const TensorProto& raw_val = node.attr().at("value").tensor();
*output_properties[i].mutable_value() = raw_val; *output_properties[i].mutable_value() = raw_val;
} else if (ctx->output_tensor_protos.size() > i && } else if (ctx->output_tensor_protos.size() > i &&
ctx->output_tensor_protos[i] != nullptr) { ctx->output_tensor_protos[i] != nullptr) {
*output_properties[i].mutable_value() = *ctx->output_tensor_protos[i]; *output_properties[i].mutable_value() =
*ctx->output_tensor_protos[i];
} else if (ctx->output_tensors_as_shapes.size() > i && } else if (ctx->output_tensors_as_shapes.size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar( IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->output(i), ctx->output_tensors_as_shapes[i], ic, ic->output(i), ctx->output_tensors_as_shapes[i],
@ -2272,6 +2282,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
} }
} }
} }
}
// Help trace the unknown dimensions to their origins. // Help trace the unknown dimensions to their origins.
VerboseLogUnknownDimensionSources(item_.graph, input_properties_, VerboseLogUnknownDimensionSources(item_.graph, input_properties_,

@ -89,11 +89,15 @@ class GraphProperties {
// output values when possible and does other aggressive strategies. // output values when possible and does other aggressive strategies.
// Similar to assuming_valid_feeds, this may cause incorrectness in graph // Similar to assuming_valid_feeds, this may cause incorrectness in graph
// analyses, but is useful for simulation or scheduling. // analyses, but is useful for simulation or scheduling.
// If include_values is true, the values of constant tensors will be
// included in the input and output properties.
Status InferStatically(bool assume_valid_feeds, Status InferStatically(bool assume_valid_feeds,
bool aggressive_shape_inference); bool aggressive_shape_inference,
bool include_tensor_values);
Status InferStatically(bool assume_valid_feeds) { Status InferStatically(bool assume_valid_feeds) {
return InferStatically(assume_valid_feeds, return InferStatically(assume_valid_feeds,
/*aggressive_shape_inference=*/false); /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true);
} }
// Infer the shape by running the graph on the specified cluster and recording // Infer the shape by running the graph on the specified cluster and recording
// the shapes of the processed tensors. // the shapes of the processed tensors.
@ -117,6 +121,7 @@ class GraphProperties {
const string& node_name) const; const string& node_name) const;
const std::vector<OpInfo::TensorProperties>& GetOutputProperties( const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
const string& node_name) const; const string& node_name) const;
// Invalidate input/output properties for nodes modified during graph // Invalidate input/output properties for nodes modified during graph
// optimization pass, to prevent potential optimizations, based on incorrect // optimization pass, to prevent potential optimizations, based on incorrect
// shape information. // shape information.

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/graph_def_util.h"
@ -1009,7 +1010,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
GraphProperties properties(item); GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("fill"); const auto out_props = properties.GetOutputProperties("fill");
const OpInfo::TensorProperties out_prop0 = out_props[0]; const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("float: [4,4]", PropToString(out_prop0)); EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
@ -1028,7 +1030,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
GraphProperties properties(item); GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("fill"); const auto out_props = properties.GetOutputProperties("fill");
const OpInfo::TensorProperties out_prop0 = out_props[0]; const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0)); EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
@ -1248,10 +1251,12 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
// evaluate output value. // evaluate output value.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/true, /*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/false)); /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("MyFunc"); const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0]; const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("int32: [2]", PropToString(out_prop0)); EXPECT_EQ("int32: [2]", PropToString(out_prop0));
LOG(INFO) << out_prop0.DebugString();
EXPECT_FALSE(out_prop0.has_value()); EXPECT_FALSE(out_prop0.has_value());
} }
@ -1260,7 +1265,8 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
// With aggressive_shape_inference, output value is evaluated. // With aggressive_shape_inference, output value is evaluated.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/true, /*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("MyFunc"); const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0]; const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("int32: [2]", PropToString(out_prop0)); EXPECT_EQ("int32: [2]", PropToString(out_prop0));
@ -1802,7 +1808,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
GraphProperties properties(item); GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false)); /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value()); EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
} }
@ -1812,7 +1819,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
GraphProperties properties(item); GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value()); EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value());
const auto slice_value = const auto slice_value =
properties.GetOutputProperties("slice").at(0).value(); properties.GetOutputProperties("slice").at(0).value();
@ -1838,7 +1846,8 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
GraphProperties properties(item); GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
// Check output shapes and values. // Check output shapes and values.
const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0]; const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0];
@ -1881,7 +1890,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
// Without aggressive_shape_inference, ignore annotated information. // Without aggressive_shape_inference, ignore annotated information.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false)); /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity"); const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size()); EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0]; const OpInfo::TensorProperties& prop = props[0];
@ -1895,7 +1905,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
// Use annotated information. // Use annotated information.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity"); const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size()); EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0]; const OpInfo::TensorProperties& prop = props[0];
@ -1923,7 +1934,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
// Use annotated information. // Use annotated information.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity"); const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size()); EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0]; const OpInfo::TensorProperties& prop = props[0];
@ -1950,7 +1962,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
// Use annotated information. // Use annotated information.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity"); const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size()); EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0]; const OpInfo::TensorProperties& prop = props[0];
@ -1977,7 +1990,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
// Use annotated information. // Use annotated information.
TF_CHECK_OK(properties.InferStatically( TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false, /*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true)); /*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn"); const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
EXPECT_EQ(1, props.size()); EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0]; const OpInfo::TensorProperties& prop = props[0];

@ -359,7 +359,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
graph_properties_ = absl::make_unique<GraphProperties>(*item); graph_properties_ = absl::make_unique<GraphProperties>(*item);
if (use_static_shapes_) { if (use_static_shapes_) {
TF_RETURN_IF_ERROR(graph_properties_->InferStatically( TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
true, use_aggressive_shape_inference_)); true, use_aggressive_shape_inference_, true));
} else { } else {
TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_)); TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
} }

@ -234,6 +234,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
DedupControlInputs(target_node); DedupControlInputs(target_node);
} }
bool IsReallyConstant(const NodeDef& node) const {
if (!IsConstant(node)) {
return false;
}
// If the node is fed it's not constant anymore.
return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
}
bool IsInPreserveSet(const NodeDef& node) const { bool IsInPreserveSet(const NodeDef& node) const {
return ctx().nodes_to_preserve->find(node.name()) != return ctx().nodes_to_preserve->find(node.name()) !=
ctx().nodes_to_preserve->end(); ctx().nodes_to_preserve->end();
@ -259,6 +267,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
return false; return false;
} }
bool GetTensorFromConstNode(const string& node_name_or_input,
Tensor* tensor) {
const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
return node != nullptr && IsReallyConstant(*node) &&
CheckAttrExists(*node, "value").ok() &&
tensor->FromProto(node->attr().at("value").tensor());
}
private: private:
// Extended context required for ArithmeticOptimizer. // Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_; const ArithmeticOptimizerContext ctx_ext_;
@ -2480,27 +2496,16 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
bool IsSupported(const NodeDef* node) const override { bool IsSupported(const NodeDef* node) const override {
return IsPow(*node) && return IsPow(*node) &&
ctx().graph_properties->GetInputProperties(node->name()).size() == 2; ctx().graph_properties->HasOutputProperties(node->name()) &&
ctx().graph_properties->HasInputProperties(node->name());
} }
Status TrySimplify(NodeDef* node, string* simplified_node_name) override { Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
const auto& pow_props = Tensor pow;
ctx().graph_properties->GetInputProperties(node->name())[1]; if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
PartialTensorShape shape(pow_props.shape());
if (!shape.IsFullyDefined()) {
// skip if p is not fully defined.
return Status::OK();
}
if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
Tensor pow(pow_props.dtype(), pow_props.shape());
if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
pow_props.value().DebugString());
}
complex128 prev, curr; complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) { for (int i = 0; i < pow.NumElements(); ++i) {
if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) { if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip. // input data type is not supported by Pow. Skip.
return Status::OK(); return Status::OK();
} }
@ -2536,34 +2541,23 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
AddToOptimizationQueue(node); AddToOptimizationQueue(node);
AddToOptimizationQueue(y); AddToOptimizationQueue(y);
} else if (curr == complex128(0, 0) && } else if (curr == complex128(0, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
PartialTensorShape shape(value_props.shape()); PartialTensorShape(output_shape).IsFullyDefined()) {
if (!shape.IsFullyDefined()) { const auto dtype = node->attr().at("T").type();
// skip if b is not fully defined. Tensor ones(dtype, output_shape);
return Status::OK(); for (int i = 0; i < ones.NumElements(); ++i) {
} TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
if (TensorShape::IsValid(value_props.shape()) &&
value_props.has_value()) {
Tensor base(value_props.dtype(), value_props.shape());
if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
value_props.value().DebugString());
} }
node->set_op("Const"); node->set_op("Const");
Tensor c(base.dtype(), base.shape()); (*node->mutable_attr())["dtype"].set_type(dtype);
for (int i = 0; i < c.NumElements(); ++i) {
TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
}
(*node->mutable_attr())["dtype"].set_type(base.dtype());
c.AsProtoTensorContent(
(*node->mutable_attr())["value"].mutable_tensor());
node->mutable_attr()->erase("T"); node->mutable_attr()->erase("T");
ones.AsProtoTensorContent(
(*node->mutable_attr())["value"].mutable_tensor());
node->set_input(0, AsControlDependency(x->name())); node->set_input(0, AsControlDependency(x->name()));
node->set_input(1, AsControlDependency(y->name())); node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node); AddToOptimizationQueue(node);
AddToOptimizationQueue(x); AddToOptimizationQueue(x);
AddToOptimizationQueue(y); AddToOptimizationQueue(y);
}
} else if (curr == complex128(-0.5, 0)) { } else if (curr == complex128(-0.5, 0)) {
node->set_op("Rsqrt"); node->set_op("Rsqrt");
node->set_input(1, AsControlDependency(y->name())); node->set_input(1, AsControlDependency(y->name()));
@ -2575,7 +2569,6 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
AddToOptimizationQueue(node); AddToOptimizationQueue(node);
AddToOptimizationQueue(y); AddToOptimizationQueue(y);
} }
}
return Status::OK(); return Status::OK();
} }
@ -2638,12 +2631,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
} }
private: private:
Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j, Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
bool* modified) { bool* modified) {
const auto& t = const auto& t =
ctx().graph_properties->GetInputProperties(input->name())[i]; ctx().graph_properties->GetInputProperties(add_node->name())[i];
const auto& c = const auto& c =
ctx().graph_properties->GetInputProperties(input->name())[j]; ctx().graph_properties->GetInputProperties(add_node->name())[j];
for (int k = 0; k < c.shape().dim_size(); ++k) { for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined. // Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) { if (c.shape().dim(k).size() < 0) {
@ -2659,13 +2652,13 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
// broadcast. // broadcast.
return Status::OK(); return Status::OK();
} }
if (TensorShape::IsValid(c.shape()) && c.has_value()) { Tensor constant;
Tensor constant(c.dtype(), c.shape()); if (GetTensorFromConstNode(add_node->input(j), &constant)) {
if (!constant.FromProto(c.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
c.value().DebugString());
}
complex128 element; complex128 element;
// TODO(rmlarsen): Refactor the more general IsOnes from
// constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
// or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1),
// and Neg(-1) to 1.
for (int k = 0; k < constant.NumElements(); ++k) { for (int k = 0; k < constant.NumElements(); ++k) {
if (!GetElementUnexhaustive(constant, k, if (!GetElementUnexhaustive(constant, k,
{DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
@ -2680,15 +2673,15 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
} }
} }
NodeDef *x, *y; NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x)); TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y)); TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
node->set_op("Log1p"); node->set_op("Log1p");
node->set_input(0, input->input(i)); node->set_input(0, add_node->input(i));
node->add_input(AsControlDependency(y->name())); node->add_input(AsControlDependency(y->name()));
ForwardControlDependencies(node, {input}); ForwardControlDependencies(node, {add_node});
AddToOptimizationQueue(node); AddToOptimizationQueue(node);
AddToOptimizationQueue(input); AddToOptimizationQueue(add_node);
AddToOptimizationQueue(x); AddToOptimizationQueue(x);
AddToOptimizationQueue(y); AddToOptimizationQueue(y);
*modified = true; *modified = true;
@ -2717,25 +2710,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) { if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
return Status::OK(); return Status::OK();
} }
const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
NodeDef* exp;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
if (!IsExp(*exp)) {
return Status::OK();
}
if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) {
return Status::OK();
}
const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1]; const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
return Status::OK();
}
}
TensorShapeProto broadcast_shape; TensorShapeProto broadcast_shape;
if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
return Status::OK(); return Status::OK();
@ -2745,12 +2721,9 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
// broadcast. // broadcast.
return Status::OK(); return Status::OK();
} }
if (TensorShape::IsValid(c.shape()) && c.has_value()) { Tensor constant;
Tensor constant(c.dtype(), c.shape()); if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
if (!constant.FromProto(c.value())) { // TODO(rmlarsen): Use the more general IsOnes helper here.
return errors::InvalidArgument("Cannot parse tensor from proto: ",
c.value().DebugString());
}
complex128 element; complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) { for (int k = 0; k < constant.NumElements(); ++k) {
if (!GetElementUnexhaustive(constant, k, if (!GetElementUnexhaustive(constant, k,
@ -2760,11 +2733,14 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
// input data type is not supported by expm1. Skip. // input data type is not supported by expm1. Skip.
return Status::OK(); return Status::OK();
} }
LOG(INFO) << "Got element = " << element;
if (element != complex128(1)) { if (element != complex128(1)) {
// current element is not 1. Skip. // current element is not 1. Skip.
return Status::OK(); return Status::OK();
} }
} }
NodeDef* exp;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
NodeDef *exp_input, *ones; NodeDef *exp_input, *ones;
TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input)); TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones)); TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
@ -2777,7 +2753,7 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
AddToOptimizationQueue(exp); AddToOptimizationQueue(exp);
AddToOptimizationQueue(exp_input); AddToOptimizationQueue(exp_input);
AddToOptimizationQueue(ones); AddToOptimizationQueue(ones);
} *simplified_node_name = node->name();
return Status::OK(); return Status::OK();
} }
}; };
@ -3096,14 +3072,6 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
} }
protected: protected:
bool IsReallyConstant(const NodeDef& node) const {
if (!IsConstant(node)) {
return false;
}
// If the node is fed it's not constant anymore.
return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
}
bool GetConstantAsInt64(const NodeDef& node, DataType dtype, bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
std::vector<int64>* values) { std::vector<int64>* values) {
if (dtype == DT_INT32) { if (dtype == DT_INT32) {
@ -3430,8 +3398,6 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
void ArithmeticOptimizer::DedupComputations() { void ArithmeticOptimizer::DedupComputations() {
CanonicalizeGraph(optimized_graph_); CanonicalizeGraph(optimized_graph_);
// LOG(INFO) << "Graph after canonicalization: \n"
// << optimized_graph_->DebugString();
GraphTopologyView graph_view; GraphTopologyView graph_view;
if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) { if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
@ -3683,7 +3649,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
graph_properties_.reset(new GraphProperties(optimized_item)); graph_properties_.reset(new GraphProperties(optimized_item));
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
const Status status = graph_properties_->InferStatically(assume_valid_feeds); const Status status =
graph_properties_->InferStatically(assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false);
const bool can_use_shapes = status.ok(); const bool can_use_shapes = status.ok();
if (!can_use_shapes) { if (!can_use_shapes) {
VLOG(1) << "Shape inference failed." << status.error_message(); VLOG(1) << "Shape inference failed." << status.error_message();

@ -122,27 +122,6 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input; return removed_input;
} }
bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
int* axis) {
if (node->op() != "ConcatV2" ||
properties.GetInputProperties(node->name()).empty()) {
return false;
}
const auto& axis_input = properties.GetInputProperties(node->name()).back();
if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
return false;
}
Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
if (!axis_tensor.FromProto(axis_input.value())) {
return false;
}
*axis = axis_input.dtype() == DT_INT64
? static_cast<int>(axis_tensor.scalar<int64>()())
: axis_tensor.scalar<int32>()();
return true;
}
bool HasTPUAttributes(const NodeDef& node) { bool HasTPUAttributes(const NodeDef& node) {
AttrSlice attrs(node); AttrSlice attrs(node);
for (auto attr : attrs) { for (auto attr : attrs) {
@ -220,9 +199,9 @@ string ConstantFolding::AddControlDependency(const string& input_name,
if (IsControlInput(input_name)) { if (IsControlInput(input_name)) {
return input_name; return input_name;
} }
const NodeDef* node = node_map->GetNode(input_name); const NodeDef& node = *node_map->GetNode(input_name);
if (!IsSwitch(*node)) { if (!IsSwitch(node)) {
return AsControlDependency(*node); return AsControlDependency(node);
} else { } else {
// We can't anchor control dependencies directly on the switch node: unlike // We can't anchor control dependencies directly on the switch node: unlike
// other nodes only one of the outputs of the switch node will be generated // other nodes only one of the outputs of the switch node will be generated
@ -230,10 +209,10 @@ string ConstantFolding::AddControlDependency(const string& input_name,
// dependency is only triggered when the corresponding output is triggered. // dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the // We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency. // switch node, and use it to anchor the control dependency.
auto outputs = node_map->GetOutputs(node->name()); auto outputs = node_map->GetOutputs(node.name());
for (const NodeDef* output : outputs) { for (const NodeDef* output : outputs) {
if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) { if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
if (IsSameInput(node->input(0), input_name)) { if (IsSameInput(node.input(0), input_name)) {
return AsControlDependency(*output); return AsControlDependency(*output);
} }
} }
@ -244,19 +223,19 @@ string ConstantFolding::AddControlDependency(const string& input_name,
string ctrl_dep_name = ParseNodeName(input_name, &port); string ctrl_dep_name = ParseNodeName(input_name, &port);
strings::StrAppend(&ctrl_dep_name, "_", port); strings::StrAppend(&ctrl_dep_name, "_", port);
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl); ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type(); const DataType output_type = node.attr().at("T").type();
NodeDef* added_node = node_map->GetNode(ctrl_dep_name); NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
if (added_node == nullptr) { if (added_node == nullptr) {
added_node = graph->add_node(); added_node = graph->add_node();
added_node->set_name(ctrl_dep_name); added_node->set_name(ctrl_dep_name);
added_node->set_op("Identity"); added_node->set_op("Identity");
added_node->set_device(node->device()); added_node->set_device(node.device());
(*added_node->mutable_attr())["T"].set_type(output_type); (*added_node->mutable_attr())["T"].set_type(output_type);
*added_node->add_input() = input_name; *added_node->add_input() = input_name;
node_map->AddNode(added_node->name(), added_node); node_map->AddNode(added_node->name(), added_node);
node_map->AddOutput(node->name(), added_node->name()); node_map->AddOutput(node.name(), added_node->name());
} }
return AsControlDependency(*added_node); return AsControlDependency(*added_node);
} }
@ -321,6 +300,15 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
return feed_nodes_.find(node.name()) == feed_nodes_.end(); return feed_nodes_.find(node.name()) == feed_nodes_.end();
} }
// TODO(rmlarsen): Refactor to shared util.
bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
Tensor* tensor) {
const NodeDef* node = node_map_->GetNode(node_name_or_input);
return node != nullptr && IsReallyConstant(*node) &&
CheckAttrExists(*node, "value").ok() &&
tensor->FromProto(node->attr().at("value").tensor());
}
// Materialize the shapes using constants whenever possible. // Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies and hold // We may add some nodes to the graph to encode control dependencies and hold
@ -868,6 +856,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false; return false;
} }
if (node.op() == "AccumulateNV2") {
return false;
}
// Skips ops that don't benefit from folding. // Skips ops that don't benefit from folding.
if (IsPlaceholder(node)) { if (IsPlaceholder(node)) {
return false; return false;
@ -1856,9 +1847,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
SET_AND_RETURN_IF_MODIFIED( SET_AND_RETURN_IF_MODIFIED(
PartialAssocOpConstFolding(optimized_graph, properties, node)); PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED( SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node)); MergeConcat(use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED( SET_AND_RETURN_IF_MODIFIED(
MergeConcat(*properties, use_shape_info, optimized_graph, node)); PartialConcatConstFolding(optimized_graph, properties, node));
graph_modified_ = graph_modified_cached; graph_modified_ = graph_modified_cached;
return Status::OK(); return Status::OK();
@ -1879,26 +1870,18 @@ void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
Status ConstantFolding::RemoveShuffleOrTranspose( Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info, const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) { GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) {
// Not optimizable.
return Status::OK(); return Status::OK();
} Tensor permutation_tensor;
const auto& p = properties.GetInputProperties(node->name())[1]; if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
if (TensorShape::IsValid(p.shape()) && p.has_value()) { properties.HasInputProperties(node->name())) {
Tensor perm(p.dtype(), p.shape()); const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (!perm.FromProto(p.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
}
std::vector<int> permutation; std::vector<int> permutation;
for (int j = 0; j < perm.NumElements(); ++j) { for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
if (perm.dtype() == DT_INT64) { if (permutation_tensor.dtype() == DT_INT64) {
permutation.push_back(perm.vec<int64>()(j)); permutation.push_back(permutation_tensor.vec<int64>()(j));
} else { } else {
permutation.push_back(perm.vec<int>()(j)); permutation.push_back(permutation_tensor.vec<int>()(j));
} }
} }
if (permutation.size() != shape.dim_size()) { if (permutation.size() != shape.dim_size()) {
@ -1914,8 +1897,6 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return Status::OK();
}
} }
} }
return Status::OK(); return Status::OK();
@ -1941,20 +1922,12 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
if (use_shape_info && node->op() == "ReverseV2" && if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
properties.GetInputProperties(node->name()).size() >= 2) { Tensor axis;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &axis)) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape(); const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) { if (shape.unknown_rank()) return Status::OK();
// Not optimizable.
return Status::OK();
}
const auto& a = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(a.shape()) && a.has_value()) {
Tensor axis(a.dtype(), a.shape());
if (!axis.FromProto(a.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
a.value().DebugString());
}
std::set<int> target_axes; std::set<int> target_axes;
for (int j = 0; j < axis.NumElements(); ++j) { for (int j = 0; j < axis.NumElements(); ++j) {
// value of axis can be negative. // value of axis can be negative.
@ -1971,16 +1944,15 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
// unknown_rank == false && // unknown_rank == false &&
// (dim_size == 0 || all dims have size 1 || // (dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not in target_axes) // all dims with > 1 size are not in target_axes)
bool replaceable = !shape.unknown_rank(); bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) { for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 || replaceable &=
target_axes.find(j) == target_axes.end(); shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} }
} }
}
return Status::OK(); return Status::OK();
} }
@ -1988,23 +1960,13 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
if (use_shape_info && IsSlice(*node) && if (!use_shape_info || !IsSlice(*node)) return Status::OK();
properties.GetInputProperties(node->name()).size() == 3) { Tensor begin;
Tensor size;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &begin) &&
GetTensorFromConstNode(node->input(2), &size)) {
const auto& input = properties.GetInputProperties(node->name())[0]; const auto& input = properties.GetInputProperties(node->name())[0];
const auto& b = properties.GetInputProperties(node->name())[1];
const auto& s = properties.GetInputProperties(node->name())[2];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
if (!begin.FromProto(b.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
b.value().DebugString());
}
Tensor size(s.dtype(), s.shape());
if (!size.FromProto(s.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
s.value().DebugString());
}
// The node is replaceable iff unknown_rank == false && // The node is replaceable iff unknown_rank == false &&
// begin == 0 && (size == -1 || size == input_shape) for all dimensions // begin == 0 && (size == -1 || size == input_shape) for all dimensions
bool replaceable = !input.shape().unknown_rank(); bool replaceable = !input.shape().unknown_rank();
@ -2024,8 +1986,6 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return Status::OK();
}
} }
} }
return Status::OK(); return Status::OK();
@ -2052,27 +2012,18 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
return Status::OK(); return Status::OK();
} }
} }
const auto& b = properties.GetInputProperties(node->name())[1];
const auto& e = properties.GetInputProperties(node->name())[2]; std::vector<Tensor> input_tensors(3);
const auto& s = properties.GetInputProperties(node->name())[3]; for (int i = 1; i < 4; ++i) {
if (TensorShape::IsValid(b.shape()) && b.has_value() && if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
TensorShape::IsValid(e.shape()) && e.has_value() && return Status::OK();
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
if (!begin.FromProto(b.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
b.value().DebugString());
} }
Tensor end(e.dtype(), e.shape());
if (!end.FromProto(e.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
e.value().DebugString());
}
Tensor strides(s.dtype(), s.shape());
if (!strides.FromProto(s.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
s.value().DebugString());
} }
const Tensor& begin = input_tensors[0];
const Tensor& end = input_tensors[1];
const Tensor& strides = input_tensors[2];
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"})); CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i(); int begin_mask = node->attr().at("begin_mask").i();
@ -2116,34 +2067,26 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
} }
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i) int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
: begin.vec<int64>()(i); : begin.vec<int64>()(i);
int e = int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i) int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
: strides.vec<int64>()(i); : strides.vec<int64>()(i);
replaceable &= replaceable &= (begin_mask & 1 << i || b == 0) &&
(begin_mask & 1 << i || b == 0) && (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
(end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1; s == 1;
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} }
} }
}
return Status::OK(); return Status::OK();
} }
Status ConstantFolding::SimplifyTile(const GraphProperties& properties, Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) { GraphDef* optimized_graph, NodeDef* node) {
Tensor multiplies;
if (use_shape_info && IsTile(*node) && if (use_shape_info && IsTile(*node) &&
properties.GetInputProperties(node->name()).size() == 2) { GetTensorFromConstNode(node->input(1), &multiplies)) {
const auto& m = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
Tensor multiplies(m.dtype(), m.shape());
if (!multiplies.FromProto(m.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
m.value().DebugString());
}
// The node is replaceable iff all values in multiplies are 1. // The node is replaceable iff all values in multiplies are 1.
bool replaceable = true; bool replaceable = true;
if (multiplies.dtype() == DT_INT32) { if (multiplies.dtype() == DT_INT32) {
@ -2151,8 +2094,7 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
replaceable &= multiplies.vec<int>()(j) == 1; replaceable &= multiplies.vec<int>()(j) == 1;
} }
} else { } else {
for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
++j) {
replaceable &= multiplies.vec<int64>()(j) == 1; replaceable &= multiplies.vec<int64>()(j) == 1;
} }
} }
@ -2160,22 +2102,16 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} }
} }
}
return Status::OK(); return Status::OK();
} }
Status ConstantFolding::SimplifyPad(const GraphProperties& properties, Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) { GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && IsPad(*node) && if (!use_shape_info || !IsPad(*node)) return Status::OK();
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& p = properties.GetInputProperties(node->name())[1]; Tensor paddings;
if (TensorShape::IsValid(p.shape()) && p.has_value()) { if (GetTensorFromConstNode(node->input(1), &paddings)) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
}
// The node is replaceable iff all values in paddings are 0. // The node is replaceable iff all values in paddings are 0.
bool replaceable = true; bool replaceable = true;
// The operation requires it to be int32 value so we don't check for // The operation requires it to be int32 value so we don't check for
@ -2188,7 +2124,6 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} }
} }
}
return Status::OK(); return Status::OK();
} }
@ -3031,16 +2966,17 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
// folding of ops when more than one but not all inputs are constant. // folding of ops when more than one but not all inputs are constant.
// For AddN and AccumulateNV2, we may furthermore reorder inputs, since // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
// addition is commutative. // addition is commutative.
if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
const int num_non_control_inputs = NumNonControlInputs(*node); const int num_non_control_inputs = NumNonControlInputs(*node);
if (IsAggregate(*node) && IsCommutative(*node) && if (num_non_control_inputs <= 2) return false;
num_non_control_inputs > 2) {
const int num_control_inputs = node->input_size() - num_non_control_inputs; const int num_control_inputs = node->input_size() - num_non_control_inputs;
std::vector<int> const_inputs; std::vector<int> const_inputs;
std::vector<int> nonconst_inputs; std::vector<int> nonconst_inputs;
for (int i = 0; i < node->input_size(); ++i) { for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i); const string& input = node->input(i);
const NodeDef* input_node = node_map_->GetNode(NodeName(input)); const NodeDef* input_node = node_map_->GetNode(NodeName(input));
CHECK(input_node != nullptr) << input; if (input_node == nullptr) return false;
if (!IsControlInput(input) && IsReallyConstant(*input_node)) { if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
const_inputs.push_back(i); const_inputs.push_back(i);
} else { } else {
@ -3058,8 +2994,7 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
} }
const string new_node_name = OptimizedNodeName( const string new_node_name = OptimizedNodeName(
*node, strings::StrCat("_partial_split_", const_inputs.size())); *node, strings::StrCat("_partial_split_", const_inputs.size()));
if (1 < const_inputs.size() && if (const_inputs.size() > 1 && const_inputs.size() < num_non_control_inputs &&
const_inputs.size() < num_non_control_inputs &&
!node_map_->NodeExists(new_node_name)) { !node_map_->NodeExists(new_node_name)) {
NodeDef* added_node = optimized_graph->add_node(); NodeDef* added_node = optimized_graph->add_node();
*added_node = *node; *added_node = *node;
@ -3091,13 +3026,11 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
} }
node->mutable_input()->DeleteSubrange(nonconst_inputs.size(), node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
const_inputs.size() - 1); const_inputs.size() - 1);
(*node->mutable_attr())["N"].set_i(node->input_size() - (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
num_control_inputs);
properties->ClearInputProperties(node->name()); properties->ClearInputProperties(node->name());
(*added_node->mutable_attr())["N"].set_i(const_inputs.size()); (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
return true; return true;
} }
}
return false; return false;
} }
@ -3107,9 +3040,12 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
// Partial constant folding for Concat which is not commutative, so // Partial constant folding for Concat which is not commutative, so
// we have to preserve order and can only push consecutive runs of constant // we have to preserve order and can only push consecutive runs of constant
// inputs into sub-nodes. // inputs into sub-nodes.
if (!IsConcat(*node) ||
node->name().rfind("_partial_split_") != string::npos) {
return false;
}
const int num_non_control_inputs = NumNonControlInputs(*node); const int num_non_control_inputs = NumNonControlInputs(*node);
if (IsConcat(*node) && num_non_control_inputs > 3 && if (num_non_control_inputs <= 3) return false;
node->name().rfind("_partial_split_") == string::npos) {
int axis_arg = -1; int axis_arg = -1;
int begin = 0; int begin = 0;
int end = num_non_control_inputs; int end = num_non_control_inputs;
@ -3123,14 +3059,6 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
return false; return false;
} }
const NodeDef* axis_arg_node =
node_map_->GetNode(NodeName(node->input(axis_arg)));
if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
// We cannot constant fold Concat unless we the axis argument is
// constant. Skip node.
return false;
}
// We search for consecutive runs of constant inputs in the range // We search for consecutive runs of constant inputs in the range
// [begin:end[ and push then down into child nodes. // [begin:end[ and push then down into child nodes.
std::vector<std::pair<int, int>> constant_input_runs; std::vector<std::pair<int, int>> constant_input_runs;
@ -3143,8 +3071,8 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
} }
// Invariant: node[first] is constant || first >= end. // Invariant: node[first] is constant || first >= end.
last = first + 1; last = first + 1;
while (last < end && IsReallyConstant(*node_map_->GetNode( while (last < end &&
NodeName(node->input(last))))) { IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
++last; ++last;
} }
// Invariant: node[last] is not constant || last >= end // Invariant: node[last] is not constant || last >= end
@ -3156,9 +3084,9 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
} }
// Skip if all inputs are constant, and let constant folding take over. // Skip if all inputs are constant, and let constant folding take over.
if (constant_input_runs.size() == 1 && if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
constant_input_runs[0].first == begin && constant_input_runs[0].first == begin &&
constant_input_runs[0].second == end) { constant_input_runs[0].second == end)) {
return false; return false;
} }
std::set<int> inputs_to_delete; std::set<int> inputs_to_delete;
@@ -3172,35 +3100,32 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
     NodeDef* added_node = optimized_graph->add_node();
     *added_node = *node;
+    added_node->set_op("ConcatV2");
     added_node->set_name(new_node_name);
     node_map_->AddNode(added_node->name(), added_node);
     added_node->clear_input();
     for (int i = interval.first; i < interval.second; ++i) {
       added_node->add_input(node->input(i));
-      node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
-                              added_node->name());
+      node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
       if (i != interval.first) {
         inputs_to_delete.insert(i);
       }
     }
     added_node->add_input(node->input(axis_arg));
-    (*added_node->mutable_attr())["N"].set_i(interval.second -
-                                             interval.first);
+    (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
     node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
     // Overwrite the first constant input with the result of the added
     // child node.
     node->set_input(interval.first, added_node->name());
-    node_map_->AddOutput(added_node->name(), node->name());
   }
-  if (!constant_input_runs.empty()) {
-    if (!inputs_to_delete.empty()) {
+  if (!constant_input_runs.empty() && !inputs_to_delete.empty()) {
     // Fix up the inputs to the original node.
-      std::vector<string> tmp(node->input().begin(), node->input().end());
-      node->clear_input();
+    protobuf::RepeatedPtrField<string> tmp;
+    tmp.Swap(node->mutable_input());
     for (int i = 0; i < tmp.size(); ++i) {
       if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
-          node->add_input(tmp[i]);
+        node->add_input(tmp.Get(i));
       }
     }
     (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
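For orientation, the hunk above rewrites PartialConcatConstFolding so that each run of consecutive constant inputs of a ConcatV2 is pushed into a child ConcatV2, which regular constant folding can then collapse into a single Const. The sketch below is illustrative only; it uses the AddNode helper from the constant_folding_test.cc hunk further down, and the graph and node names are invented rather than taken from this change.

// Before: "c0", "c1", "c2" are Const nodes, "x" is a Placeholder.
GraphDef before;
AddNode("concat", "ConcatV2", {"c0", "c1", "x", "c2", "axis"}, {}, &before);

// After PartialConcatConstFolding: the constant run {c0, c1} is moved into a
// child ConcatV2 (later folded to one Const), and the parent keeps the rest.
GraphDef after;
AddNode("ConstantFolding/concat_partial_split_0", "ConcatV2",
        {"c0", "c1", "axis"}, {}, &after);
AddNode("concat", "ConcatV2",
        {"ConstantFolding/concat_partial_split_0", "x", "c2", "axis"}, {},
        &after);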
@@ -3208,55 +3133,83 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
-    }
     return true;
   }
   return false;
 }

+bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
+  if (node.op() != "ConcatV2") {
+    return false;
+  }
+  int axis_idx = node.input_size() - 1;
+  while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
+    --axis_idx;
+  }
+  if (axis_idx <= 0) {
+    return false;
+  }
+  Tensor axis_tensor;
+  if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
+    return false;
+  }
+  *axis = axis_tensor.dtype() == DT_INT64
+              ? static_cast<int>(axis_tensor.scalar<int64>()())
+              : axis_tensor.scalar<int32>()();
+  return true;
+}
+
-bool ConstantFolding::MergeConcat(const GraphProperties& properties,
-                                  bool use_shape_info,
+bool ConstantFolding::MergeConcat(bool use_shape_info,
                                   GraphDef* optimized_graph, NodeDef* node) {
   // We only optimize for ConcatV2.
   int axis;
-  if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+  if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
       nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
       node_map_->GetOutputs(node->name()).size() != 1) {
     return false;
   }
+  // If all inputs are constant, don't merge and let folding take care of it.
+  const int num_regular_inputs = NumNonControlInputs(*node);
+  bool all_inputs_are_const = true;
+  for (int i = 0; i < num_regular_inputs - 1; ++i) {
+    const NodeDef* input_node = node_map_->GetNode(node->input(i));
+    if (!IsReallyConstant(*input_node)) {
+      all_inputs_are_const = false;
+    }
+  }
+  if (all_inputs_are_const) return false;
   NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
   int parent_axis;
-  if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+  if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
     return false;
   }
-  const int index = NumNonControlInputs(*node) - 1;
-  auto inputs = parent->input();
-  parent->clear_input();
-  for (int i = 0; i < inputs.size(); ++i) {
-    if (IsSameInput(inputs.Get(i), node->name())) {
-      for (int j = 0; j < node->input_size(); ++j) {
-        if (j < index) {
-          // Input tensors (non axis), add to input list of parent.
-          parent->add_input(node->input(j));
-          node_map_->RemoveOutput(node->input(j), node->name());
-          node_map_->AddOutput(node->input(j), parent->name());
-        }
-        // Skip j == index, which means axis tensor.
-        if (j > index) {
-          // Control Dependencies, push back to inputs so they can be forwarded
-          // to parent.
-          *inputs.Add() = node->input(j);
-        }
-      }
-    } else {
-      parent->add_input(inputs.Get(i));
-    }
-  }
+  protobuf::RepeatedPtrField<string> parent_inputs;
+  parent_inputs.Swap(parent->mutable_input());
+  std::vector<string> ctrl_output;
+  // TODO(rmlarsen): If the child occurs more than once, is it beneficial to
+  // collapse it into the parent multiple times? Probably not.
+  for (const auto& input : parent_inputs) {
+    if (IsSameInput(input, node->name())) {
+      for (int j = 0; j < num_regular_inputs - 1; ++j) {
+        // Add the tensor inputs of the child concat (all but the final axis
+        // input) to the parent's inputs.
+        parent->add_input(node->input(j));
+        node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
+      }
+    } else {
+      parent->add_input(input);
+    }
+  }
+  // Forward the child's control inputs to the parent.
+  for (int i = num_regular_inputs; i < node->input_size(); ++i) {
+    parent->add_input(node->input(i));
+    node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
+  }
   node->clear_input();
   node->set_op("NoOp");
   node->clear_attr();
   node_map_->RemoveNode(node->name());
   (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+  DedupControlInputs(parent);
   return true;
 }
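To make the MergeConcat rewrite concrete, here is a hypothetical before/after in the same AddNode test-helper notation used in the constant_folding_test.cc hunk below; the graph and names are invented for illustration and are not part of this change.

// Before: the child concat has exactly one fanout (the parent) and at least
// one non-constant input ("a" is a Placeholder, "b" and "c" are Consts).
GraphDef before;
AddNode("child", "ConcatV2", {"a", "b", "axis"}, {}, &before);
AddNode("parent", "ConcatV2", {"child", "c", "axis"}, {}, &before);

// After: the child's tensor inputs are spliced into the parent at the child's
// position, the child is turned into a NoOp, and the parent's "N" attr is
// updated to the new number of tensor inputs (3 here).
GraphDef after;
AddNode("parent", "ConcatV2", {"a", "b", "c", "axis"}, {}, &after);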
@@ -3344,7 +3297,9 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   // that the shape inference deals with this conservatively unless we're in
   // aggressive mode.
   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  Status s = properties.InferStatically(assume_valid_feeds);
+  Status s = properties.InferStatically(assume_valid_feeds,
+                                        /*aggressive_shape_inference=*/false,
+                                        /*include_tensor_values=*/false);
   const bool can_use_shape_info = s.ok();
   if (can_use_shape_info) {
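All of the call sites in this and the following files now spell out three boolean arguments, so the updated declaration in tensorflow/core/grappler/costs/graph_properties.h presumably looks roughly like the sketch below. The parameter names are taken from the /*...=*/ annotations at the call sites; whether any of them carry default values is not visible in this diff.

// Assumed shape of the new GraphProperties::InferStatically declaration.
Status InferStatically(bool assume_valid_feeds,
                       bool aggressive_shape_inference,
                       bool include_tensor_values);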

@@ -61,6 +61,8 @@ class ConstantFolding : public GraphOptimizer {
   bool IsReallyConstant(const NodeDef& node) const;

+  bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
+
   Status MaterializeShapes(const GraphProperties& properties);

   Status MaterializeBroadcastGradientArgs(const NodeDef& node,
@@ -239,8 +241,9 @@ class ConstantFolding : public GraphOptimizer {
   void RemoveSplitOrSplitV(const GraphProperties& properties,
                            GraphDef* optimized_graph, NodeDef* node);
-  bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
-                   GraphDef* optimized_graph, NodeDef* node);
+  bool GetConcatAxis(const NodeDef& node, int* axis);
+  bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
+                   NodeDef* node);
   Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
                                                GraphDef* optimized_graph);
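The header hunk above only declares GetTensorFromConstNode; its definition is not among the excerpted hunks. A minimal sketch of what such a helper plausibly does, judging from how GetConcatAxis calls it, is shown below; this is an assumption, not the actual implementation from the change.

bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
                                             Tensor* tensor) {
  // Resolve the (possibly port-qualified) input string to a node, require a
  // real constant, and parse its "value" attr into a Tensor.
  const NodeDef* node = node_map_->GetNode(NodeName(node_name_or_input));
  return node != nullptr && IsReallyConstant(*node) &&
         node->attr().count("value") > 0 &&
         tensor->FromProto(node->attr().at("value").tensor());
}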

@@ -2390,12 +2390,11 @@ TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
   TF_EXPECT_OK(status);

   GraphDef want;
-  AddNode("ConstantFolding/concat2_partial_split_0_0", "Const", {}, {}, &want);
+  AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
   AddNode("axis", "Const", {}, {}, &want);
   AddNode("ph", "Placeholder", {}, {}, &want);
-  AddNode("concat2", "ConcatV2",
-          {"ConstantFolding/concat2_partial_split_0_0", "ph", "axis"}, {},
-          &want);
+  AddNode("concat2", "ConcatV2",
+          {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);

   CompareGraphs(want, got);
 }

@@ -2209,7 +2209,10 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   }

   GraphProperties graph_properties(item);
-  TF_RETURN_IF_ERROR(graph_properties.InferStatically(false));
+  TF_RETURN_IF_ERROR(
+      graph_properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
   virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));

@@ -102,7 +102,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
   if (!properties->has_properties()) {
     // This is an expensive call, call it lazily.
     TF_RETURN_IF_ERROR(properties->InferStatically(
-        /*assume_valid_feeds=*/false));
+        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/false));
   }
   const auto& output_properties = properties->GetOutputProperties(node.name());
   if (port_id >= output_properties.size()) {
@@ -252,7 +253,8 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
   if (!properties->has_properties()) {
     // This is an expensive call, call it lazily.
     TF_RETURN_IF_ERROR(properties->InferStatically(
-        /*assume_valid_feeds=*/false));
+        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/false));
   }
   for (const auto& prop : properties->GetOutputProperties(node.name())) {
     if (!IsTensorSmall(prop)) {

@@ -1170,7 +1170,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
     // Infer properties lazily in case they are not needed.
     if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
-      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(false));
+      // TODO(rmlarsen): Get rid of tensor value copies.
+      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
+          /*assume_valid_feeds=*/false,
+          /*aggressive_shape_inference=*/false,
+          /*include_tensor_values=*/true));
       ctx.inferred_graph_properties = true;
     }

@@ -739,9 +739,9 @@ Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
   GraphProperties graph_properties(item);
   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  LOG_WARNING_AND_RETURN_IF_ERROR(
-      graph_properties.InferStatically(assume_valid_feeds));
+  LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
+      assume_valid_feeds, /*aggressive_shape_inference=*/false,
+      /*include_tensor_values=*/false));
   *optimized_graph = item.graph;
   node_map_.reset(new NodeMap(optimized_graph));

@@ -87,7 +87,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
           graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
       if (!inferred_properties) {
         // Infer properties lazily in case they are not needed.
-        TF_RETURN_IF_ERROR(properties.InferStatically(false));
+        TF_RETURN_IF_ERROR(
+            properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
         inferred_properties = true;
       }
       const auto& prop =
@@ -144,7 +147,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       }
       if (!inferred_properties) {
         // Infer properties lazily in case they are not needed.
-        TF_RETURN_IF_ERROR(properties.InferStatically(false));
+        TF_RETURN_IF_ERROR(
+            properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
         inferred_properties = true;
       }
       const auto& prop1 = properties.GetInputProperties(input1.node->name());

@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/grappler/optimizers/static_schedule.h"

 #include <deque>

 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
@@ -92,7 +94,10 @@ Status EstimateEarliestExecutionTimes(
     name_map.clear();

   GraphProperties properties(item);
-  TF_RETURN_IF_ERROR(properties.InferStatically(true));
+  TF_RETURN_IF_ERROR(
+      properties.InferStatically(/*assume_valid_feeds=*/true,
+                                 /*aggressive_shape_inference=*/false,
+                                 /*include_tensor_values=*/false));
   OpLevelCostEstimator estimator;
   VirtualPlacer placer(cluster->GetDevices());
@@ -160,7 +165,10 @@ Status EstimateRequiredTimes(
     }
   }
   GraphProperties properties(item);
-  TF_RETURN_IF_ERROR(properties.InferStatically(true));
+  TF_RETURN_IF_ERROR(
+      properties.InferStatically(/*assume_valid_feeds=*/true,
+                                 /*aggressive_shape_inference=*/false,
+                                 /*include_tensor_values=*/false));
   OpLevelCostEstimator estimator;
   VirtualPlacer placer(cluster->GetDevices());

@@ -346,7 +346,8 @@ class FoldOldBatchNormsTest : public ::testing::Test {
   std::vector<Tensor> fused_outputs;
   TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
-  test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
+  test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
+                    /*rtol=*/2e-5);

   for (const NodeDef& node : fused_graph_def.node()) {
     EXPECT_NE("FusedBatchNorm", node.op());